Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-09 22:35:54 -05:00)

Compare commits: feat/mcp-b...otto/copil (3 commits)

Commits: 9e604528ea, c3ec7c2880, 7d9380a793
.github/workflows/platform-frontend-ci.yml (vendored, 16 changed lines)

@@ -27,20 +27,11 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       cache-key: ${{ steps.cache-key.outputs.key }}
-      components-changed: ${{ steps.filter.outputs.components }}

     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

-      - name: Check for component changes
-        uses: dorny/paths-filter@v3
-        id: filter
-        with:
-          filters: |
-            components:
-              - 'autogpt_platform/frontend/src/components/**'
-
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
@@ -99,11 +90,8 @@ jobs:
   chromatic:
     runs-on: ubuntu-latest
     needs: setup
-    # Disabled: to re-enable, remove 'false &&' from the condition below
-    if: >-
-      false
-      && (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
-      && needs.setup.outputs.components-changed == 'true'
+    # Only run on dev branch pushes or PRs targeting dev
+    if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'

     steps:
       - name: Checkout repository
.gitignore (vendored, 1 changed line)

@@ -180,4 +180,3 @@ autogpt_platform/backend/settings.py
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
-.next
autogpt_platform/autogpt_libs/poetry.lock (generated, 1320 changed lines)

File diff suppressed because it is too large.
@@ -11,15 +11,15 @@ python = ">=3.10,<4.0"
 colorama = "^0.4.6"
 cryptography = "^45.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.128.0"
-google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.14.1"
-pydantic = "^2.12.5"
-pydantic-settings = "^2.12.0"
-pyjwt = { version = "^2.11.0", extras = ["crypto"] }
+fastapi = "^0.116.1"
+google-cloud-logging = "^3.12.1"
+launchdarkly-server-sdk = "^9.12.0"
+pydantic = "^2.11.7"
+pydantic-settings = "^2.10.1"
+pyjwt = { version = "^2.10.1", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.27.2"
-uvicorn = "^0.40.0"
+supabase = "^2.16.0"
+uvicorn = "^0.35.0"

 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.404"
@@ -152,7 +152,6 @@ REPLICATE_API_KEY=
 REVID_API_KEY=
 SCREENSHOTONE_API_KEY=
 UNREAL_SPEECH_API_KEY=
-ELEVENLABS_API_KEY=

 # Data & Search Services
 E2B_API_KEY=
autogpt_platform/backend/.gitignore (vendored, 3 changed lines)

@@ -19,6 +19,3 @@ load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
 migrations/*/rollback*.sql
-
-# Workspace files
-workspaces/
@@ -62,12 +62,10 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH

-# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
+# Install Python without upgrading system-managed packages
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
-    ffmpeg \
-    imagemagick \
     && rm -rf /var/lib/apt/lists/*

 # Copy only necessary files from builder
Deleted file (76 lines removed; file name not shown in this view):

@@ -1,76 +0,0 @@
# MCP Block Implementation Plan

## Overview

Create a single **MCPBlock** that dynamically integrates with any MCP (Model Context Protocol)
server. Users provide a server URL, the block discovers available tools, presents them as a
dropdown, and dynamically adjusts input/output schema based on the selected tool — exactly like
`AgentExecutorBlock` handles dynamic schemas.

## Architecture

```
User provides MCP server URL + credentials
        ↓
MCPBlock fetches tools via MCP protocol (tools/list)
        ↓
User selects tool from dropdown (stored in constantInput)
        ↓
Input schema dynamically updates based on selected tool's inputSchema
        ↓
On execution: MCPBlock calls the tool via MCP protocol (tools/call)
        ↓
Result yielded as block output
```

## Design Decisions

1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
   `get_input_defaults()`, `get_missing_input()` on the Input class
3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
   `OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
   OAuth2 uses the access token.
4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
   HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
   over HTTP. Handles both JSON and SSE response formats.
5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables

## Implementation Files

### New Files
- `backend/blocks/mcp/` — MCP block package
  - `__init__.py`
  - `block.py` — MCPToolBlock implementation
  - `client.py` — MCP HTTP client (list_tools, call_tool)
  - `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
  - `test_mcp.py` — Unit tests
  - `test_oauth.py` — OAuth handler tests
  - `test_integration.py` — Integration tests with local test server
  - `test_e2e.py` — E2E tests against real MCP servers

### Modified Files
- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName

## Dev Loop

```bash
cd autogpt_platform/backend
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs          # Unit tests
poetry run pytest backend/blocks/mcp/test_oauth.py -xvs        # OAuth tests
poetry run pytest backend/blocks/mcp/test_integration.py -xvs  # Integration tests
poetry run pytest backend/blocks/mcp/ -xvs                     # All MCP tests
```

## Status

- [x] Research & Design
- [x] Add ProviderName.MCP
- [x] Implement MCP client (client.py)
- [x] Implement MCPToolBlock (block.py)
- [x] Add OAuth2 support (oauth.py)
- [x] Write unit tests
- [x] Write integration tests
- [x] Write E2E tests
- [x] Run tests & fix issues
- [x] Create PR
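For orientation, here is a minimal sketch of the JSON-RPC exchange the deleted plan describes (tools/list and tools/call over plain HTTP with aiohttp). This is not the repository's `client.py`: the endpoint URL, header handling, and result shapes are assumptions, and the real client also handles SSE-formatted responses, which this sketch omits.

```python
# Hedged illustration of MCP "Streamable HTTP" calls as JSON-RPC over HTTP.
# Not the repository implementation; URL and result shapes are assumptions.
import asyncio

import aiohttp


async def _rpc(server_url: str, method: str, params: dict, api_key: str | None) -> dict:
    payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params}
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"  # API key sent as a Bearer token
    async with aiohttp.ClientSession() as session:
        async with session.post(server_url, json=payload, headers=headers) as resp:
            resp.raise_for_status()
            data = await resp.json()
    return data.get("result", {})


async def list_tools(server_url: str, api_key: str | None = None) -> list[dict]:
    """Discover the tools an MCP server advertises (drives the dropdown/schema)."""
    result = await _rpc(server_url, "tools/list", {}, api_key)
    return result.get("tools", [])


async def call_tool(server_url: str, name: str, arguments: dict, api_key: str | None = None) -> dict:
    """Invoke one selected tool with the arguments collected from the block inputs."""
    return await _rpc(server_url, "tools/call", {"name": name, "arguments": arguments}, api_key)


if __name__ == "__main__":
    # Hypothetical server URL, for demonstration only.
    tools = asyncio.run(list_tools("https://example.com/mcp"))
    print([t.get("name") for t in tools])
```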
Deleted file (368 lines removed; the Redis Streams completion consumer):

@@ -1,368 +0,0 @@
"""Redis Streams consumer for operation completion messages.

This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.

Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:

1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
   Streams (via stream_registry) for message persistence and replay. Using Redis
   Streams for completion notifications keeps all chat streaming infrastructure
   in one system, simplifying operations and reducing cross-system coordination.

2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
   allowing consumers to replay missed messages after reconnection. This aligns
   with the SSE reconnection pattern where clients can resume from last_message_id.

3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
   load balancing across pods with explicit message claiming (XAUTOCLAIM) for
   recovering from dead consumers - ideal for the completion callback pattern.

4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
   stream_registry) provides lower latency than an additional RabbitMQ hop.

5. **Atomicity with Task State**: Completion processing often needs to update
   task metadata stored in Redis. Keeping both in Redis enables simpler
   transactional semantics without distributed coordination.

The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""

import asyncio
import logging
import os
import uuid
from typing import Any

import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError

from backend.data.redis_client import get_redis_async

from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig

logger = logging.getLogger(__name__)
config = ChatConfig()


class OperationCompleteMessage(BaseModel):
    """Message format for operation completion notifications."""

    operation_id: str
    task_id: str
    success: bool
    result: dict | str | None = None
    error: str | None = None


class ChatCompletionConsumer:
    """Consumer for chat operation completion messages from Redis Streams.

    This consumer initializes its own Prisma client in start() to ensure
    database operations work correctly within this async context.

    Uses Redis consumer groups to allow multiple platform pods to consume
    messages reliably with automatic redelivery on failure.
    """

    def __init__(self):
        self._consumer_task: asyncio.Task | None = None
        self._running = False
        self._prisma: Prisma | None = None
        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"

    async def start(self) -> None:
        """Start the completion consumer."""
        if self._running:
            logger.warning("Completion consumer already running")
            return

        # Create consumer group if it doesn't exist
        try:
            redis = await get_redis_async()
            await redis.xgroup_create(
                config.stream_completion_name,
                config.stream_consumer_group,
                id="0",
                mkstream=True,
            )
            logger.info(
                f"Created consumer group '{config.stream_consumer_group}' "
                f"on stream '{config.stream_completion_name}'"
            )
        except ResponseError as e:
            if "BUSYGROUP" in str(e):
                logger.debug(
                    f"Consumer group '{config.stream_consumer_group}' already exists"
                )
            else:
                raise

        self._running = True
        self._consumer_task = asyncio.create_task(self._consume_messages())
        logger.info(
            f"Chat completion consumer started (consumer: {self._consumer_name})"
        )

    async def _ensure_prisma(self) -> Prisma:
        """Lazily initialize Prisma client on first use."""
        if self._prisma is None:
            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
            self._prisma = Prisma(datasource={"url": database_url})
            await self._prisma.connect()
            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
        return self._prisma

    async def stop(self) -> None:
        """Stop the completion consumer."""
        self._running = False

        if self._consumer_task:
            self._consumer_task.cancel()
            try:
                await self._consumer_task
            except asyncio.CancelledError:
                pass
            self._consumer_task = None

        if self._prisma:
            await self._prisma.disconnect()
            self._prisma = None
            logger.info("[COMPLETION] Consumer Prisma client disconnected")

        logger.info("Chat completion consumer stopped")

    async def _consume_messages(self) -> None:
        """Main message consumption loop with retry logic."""
        max_retries = 10
        retry_delay = 5  # seconds
        retry_count = 0
        block_timeout = 5000  # milliseconds

        while self._running and retry_count < max_retries:
            try:
                redis = await get_redis_async()

                # Reset retry count on successful connection
                retry_count = 0

                while self._running:
                    # First, claim any stale pending messages from dead consumers
                    # Redis does NOT auto-redeliver pending messages; we must explicitly
                    # claim them using XAUTOCLAIM
                    try:
                        claimed_result = await redis.xautoclaim(
                            name=config.stream_completion_name,
                            groupname=config.stream_consumer_group,
                            consumername=self._consumer_name,
                            min_idle_time=config.stream_claim_min_idle_ms,
                            start_id="0-0",
                            count=10,
                        )
                        # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
                        if claimed_result and len(claimed_result) >= 2:
                            claimed_entries = claimed_result[1]
                            if claimed_entries:
                                logger.info(
                                    f"Claimed {len(claimed_entries)} stale pending messages"
                                )
                                for entry_id, data in claimed_entries:
                                    if not self._running:
                                        return
                                    await self._process_entry(redis, entry_id, data)
                    except Exception as e:
                        logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")

                    # Read new messages from the stream
                    messages = await redis.xreadgroup(
                        groupname=config.stream_consumer_group,
                        consumername=self._consumer_name,
                        streams={config.stream_completion_name: ">"},
                        block=block_timeout,
                        count=10,
                    )

                    if not messages:
                        continue

                    for stream_name, entries in messages:
                        for entry_id, data in entries:
                            if not self._running:
                                return
                            await self._process_entry(redis, entry_id, data)

            except asyncio.CancelledError:
                logger.info("Consumer cancelled")
                return
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
                    exc_info=True,
                )
                if self._running and retry_count < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error("Max retries reached, stopping consumer")
                    return

    async def _process_entry(
        self, redis: Any, entry_id: str, data: dict[str, Any]
    ) -> None:
        """Process a single stream entry and acknowledge it on success.

        Args:
            redis: Redis client connection
            entry_id: The stream entry ID
            data: The entry data dict
        """
        try:
            # Handle the message
            message_data = data.get("data")
            if message_data:
                await self._handle_message(
                    message_data.encode()
                    if isinstance(message_data, str)
                    else message_data
                )

            # Acknowledge the message after successful processing
            await redis.xack(
                config.stream_completion_name,
                config.stream_consumer_group,
                entry_id,
            )
        except Exception as e:
            logger.error(
                f"Error processing completion message {entry_id}: {e}",
                exc_info=True,
            )
            # Message remains in pending state and will be claimed by
            # XAUTOCLAIM after min_idle_time expires

    async def _handle_message(self, body: bytes) -> None:
        """Handle a completion message using our own Prisma client."""
        try:
            data = orjson.loads(body)
            message = OperationCompleteMessage(**data)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Find task in registry
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_success(task, message.result, prisma)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_failure(task, message.error, prisma)


# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None


async def start_completion_consumer() -> None:
    """Start the global completion consumer."""
    global _consumer
    if _consumer is None:
        _consumer = ChatCompletionConsumer()
    await _consumer.start()


async def stop_completion_consumer() -> None:
    """Stop the global completion consumer."""
    global _consumer
    if _consumer:
        await _consumer.stop()
        _consumer = None


async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Publish an operation completion message to Redis Streams.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    message = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    )

    redis = await get_redis_async()
    await redis.xadd(
        config.stream_completion_name,
        {"data": message.model_dump_json()},
        maxlen=config.stream_max_length,
    )
    logger.info(f"Published completion for operation {operation_id}")
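For context on how the deleted module was meant to be used, here is a hedged usage sketch: a producer reports a finished operation with `publish_operation_complete`, and the consumer is started and stopped around the application's lifetime. The FastAPI lifespan wiring and the import path are assumptions for illustration; only the helper functions themselves appear in the file above.

```python
# Illustrative wiring only; the import path and lifespan hookup are hypothetical.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from backend.server.v2.chat.completion_consumer import (  # hypothetical path
    publish_operation_complete,
    start_completion_consumer,
    stop_completion_consumer,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    await start_completion_consumer()  # begin consuming the completion stream
    try:
        yield
    finally:
        await stop_completion_consumer()  # cancel the loop and disconnect Prisma


app = FastAPI(lifespan=lifespan)


async def report_agent_generation_done(operation_id: str, task_id: str) -> None:
    # Producer side: one XADD entry whose "data" field is the serialized message.
    await publish_operation_complete(
        operation_id=operation_id,
        task_id=task_id,
        success=True,
        result={"agent_json": {"name": "Example agent"}},
    )
```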
Deleted file (344 lines removed; the shared completion handler):

@@ -1,344 +0,0 @@
"""Shared completion handling for operation success and failure.

This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""

import logging
from typing import Any

import orjson
from prisma import Prisma

from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse

logger = logging.getLogger(__name__)

# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}

# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
    {
        "api_key",
        "apikey",
        "api_secret",
        "password",
        "secret",
        "credentials",
        "credential",
        "token",
        "access_token",
        "refresh_token",
        "private_key",
        "privatekey",
        "auth",
        "authorization",
    }
)


def _sanitize_agent_json(obj: Any) -> Any:
    """Recursively sanitize agent_json by removing sensitive keys.

    Args:
        obj: The object to sanitize (dict, list, or primitive)

    Returns:
        Sanitized copy with sensitive keys removed/redacted
    """
    if isinstance(obj, dict):
        return {
            k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
            for k, v in obj.items()
        }
    elif isinstance(obj, list):
        return [_sanitize_agent_json(item) for item in obj]
    else:
        return obj


class ToolMessageUpdateError(Exception):
    """Raised when updating a tool message in the database fails."""

    pass


async def _update_tool_message(
    session_id: str,
    tool_call_id: str,
    content: str,
    prisma_client: Prisma | None,
) -> None:
    """Update tool message in database.

    Args:
        session_id: The session ID
        tool_call_id: The tool call ID to update
        content: The new content for the message
        prisma_client: Optional Prisma client. If None, uses chat_service.

    Raises:
        ToolMessageUpdateError: If the database update fails. The caller should
        handle this to avoid marking the task as completed with inconsistent state.
    """
    try:
        if prisma_client:
            # Use provided Prisma client (for consumer with its own connection)
            updated_count = await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": session_id,
                    "toolCallId": tool_call_id,
                },
                data={"content": content},
            )
            # Check if any rows were updated - 0 means message not found
            if updated_count == 0:
                raise ToolMessageUpdateError(
                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
                )
        else:
            # Use service function (for webhook endpoint)
            await chat_service._update_pending_operation(
                session_id=session_id,
                tool_call_id=tool_call_id,
                result=content,
            )
    except ToolMessageUpdateError:
        raise
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
        raise ToolMessageUpdateError(
            f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
        ) from e


def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
    """Serialize result to JSON string with sensible defaults.

    Args:
        result: The result to serialize. Can be a dict, list, string,
            number, boolean, or None.

    Returns:
        JSON string representation of the result. Returns '{"status": "completed"}'
        only when result is explicitly None.
    """
    if isinstance(result, str):
        return result
    if result is None:
        return '{"status": "completed"}'
    return orjson.dumps(result).decode("utf-8")


async def _save_agent_from_result(
    result: dict[str, Any],
    user_id: str | None,
    tool_name: str,
) -> dict[str, Any]:
    """Save agent to library if result contains agent_json.

    Args:
        result: The result dict that may contain agent_json
        user_id: The user ID to save the agent for
        tool_name: The tool name (create_agent or edit_agent)

    Returns:
        Updated result dict with saved agent details, or original result if no agent_json
    """
    if not user_id:
        logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
        return result

    agent_json = result.get("agent_json")
    if not agent_json:
        logger.warning(
            f"[COMPLETION] {tool_name} completed but no agent_json in result"
        )
        return result

    try:
        from .tools.agent_generator import save_agent_to_library

        is_update = tool_name == "edit_agent"
        created_graph, library_agent = await save_agent_to_library(
            agent_json, user_id, is_update=is_update
        )

        logger.info(
            f"[COMPLETION] Saved agent '{created_graph.name}' to library "
            f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
        )

        # Return a response similar to AgentSavedResponse
        return {
            "type": "agent_saved",
            "message": f"Agent '{created_graph.name}' has been saved to your library!",
            "agent_id": created_graph.id,
            "agent_name": created_graph.name,
            "library_agent_id": library_agent.id,
            "library_agent_link": f"/library/agents/{library_agent.id}",
            "agent_page_link": f"/build?flowID={created_graph.id}",
        }
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to save agent to library: {e}",
            exc_info=True,
        )
        # Return error but don't fail the whole operation
        # Sanitize agent_json to remove sensitive keys before returning
        return {
            "type": "error",
            "message": f"Agent was generated but failed to save: {str(e)}",
            "error": str(e),
            "agent_json": _sanitize_agent_json(agent_json),
        }


async def process_operation_success(
    task: stream_registry.ActiveTask,
    result: dict | str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle successful operation completion.

    Publishes the result to the stream registry, updates the database,
    generates LLM continuation, and marks the task as completed.

    Args:
        task: The active task that completed
        result: The result data from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.

    Raises:
        ToolMessageUpdateError: If the database update fails. The task will be
        marked as failed instead of completed to avoid inconsistent state.
    """
    # For agent generation tools, save the agent to library
    if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
        result = await _save_agent_from_result(result, task.user_id, task.tool_name)

    # Serialize result for output (only substitute default when result is exactly None)
    result_output = result if result is not None else {"status": "completed"}
    output_str = (
        result_output
        if isinstance(result_output, str)
        else orjson.dumps(result_output).decode("utf-8")
    )

    # Publish result to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamToolOutputAvailable(
            toolCallId=task.tool_call_id,
            toolName=task.tool_name,
            output=output_str,
            success=True,
        ),
    )

    # Update pending operation in database
    # If this fails, we must not continue to mark the task as completed
    result_str = serialize_result(result)
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=result_str,
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - mark task as failed to avoid inconsistent state
        logger.error(
            f"[COMPLETION] DB update failed for task {task.task_id}, "
            "marking as failed instead of completed"
        )
        await stream_registry.publish_chunk(
            task.task_id,
            StreamError(errorText="Failed to save operation result to database"),
        )
        await stream_registry.mark_task_completed(task.task_id, status="failed")
        raise

    # Generate LLM continuation with streaming
    try:
        await chat_service._generate_llm_continuation_with_streaming(
            session_id=task.session_id,
            user_id=task.user_id,
            task_id=task.task_id,
        )
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to generate LLM continuation: {e}",
            exc_info=True,
        )

    # Mark task as completed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="completed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(
        f"[COMPLETION] Successfully processed completion for task {task.task_id}"
    )


async def process_operation_failure(
    task: stream_registry.ActiveTask,
    error: str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle failed operation completion.

    Publishes the error to the stream registry, updates the database with
    the error response, and marks the task as failed.

    Args:
        task: The active task that failed
        error: The error message from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    error_msg = error or "Operation failed"

    # Publish error to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamError(errorText=error_msg),
    )

    # Update pending operation with error
    # If this fails, we still continue to mark the task as failed
    error_response = ErrorResponse(
        message=error_msg,
        error=error,
    )
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=error_response.model_dump_json(),
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - log but continue with cleanup
        logger.error(
            f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
            "continuing with cleanup"
        )

    # Mark task as failed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="failed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
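As a quick illustration of the two pure helpers in the deleted module, the snippet below restates `_sanitize_agent_json` and `serialize_result` with their behavior inferred directly from the code above; the trimmed key set and the sample data are assumptions made only for the example.

```python
# Self-contained demo of the redaction and serialization behavior shown above.
import orjson

SENSITIVE_KEYS = frozenset({"api_key", "token", "password"})  # trimmed for the example


def _sanitize_agent_json(obj):
    # Recursively replace values of sensitive keys with a redaction marker.
    if isinstance(obj, dict):
        return {
            k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
            for k, v in obj.items()
        }
    if isinstance(obj, list):
        return [_sanitize_agent_json(item) for item in obj]
    return obj


def serialize_result(result):
    # Strings pass through; None gets a default status; everything else becomes JSON.
    if isinstance(result, str):
        return result
    if result is None:
        return '{"status": "completed"}'
    return orjson.dumps(result).decode("utf-8")


print(_sanitize_agent_json({"name": "demo", "nodes": [{"api_key": "sk-123"}]}))
# -> {'name': 'demo', 'nodes': [{'api_key': '[REDACTED]'}]}
print(serialize_result(None))          # -> {"status": "completed"}
print(serialize_result({"ok": True}))  # -> {"ok":true}
```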
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):

     # OpenAI API Configuration
     model: str = Field(
-        default="anthropic/claude-opus-4.6", description="Default model to use"
+        default="anthropic/claude-opus-4.5", description="Default model to use"
     )
     title_model: str = Field(
         default="openai/gpt-4o-mini",
@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
         description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
     )
-
-    # Stream registry configuration for SSE reconnection
-    stream_ttl: int = Field(
-        default=3600,
-        description="TTL in seconds for stream data in Redis (1 hour)",
-    )
-    stream_max_length: int = Field(
-        default=10000,
-        description="Maximum number of messages to store per stream",
-    )
-
-    # Redis Streams configuration for completion consumer
-    stream_completion_name: str = Field(
-        default="chat:completions",
-        description="Redis Stream name for operation completions",
-    )
-    stream_consumer_group: str = Field(
-        default="chat_consumers",
-        description="Consumer group name for completion stream",
-    )
-    stream_claim_min_idle_ms: int = Field(
-        default=60000,
-        description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
-    )
-
-    # Redis key prefixes for stream registry
-    task_meta_prefix: str = Field(
-        default="chat:task:meta:",
-        description="Prefix for task metadata hash keys",
-    )
-    task_stream_prefix: str = Field(
-        default="chat:stream:",
-        description="Prefix for task message stream keys",
-    )
-    task_op_prefix: str = Field(
-        default="chat:task:op:",
-        description="Prefix for operation ID to task ID mapping keys",
-    )
-    internal_api_key: str | None = Field(
-        default=None,
-        description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
-    )

     # Langfuse Prompt Management Configuration
     # Note: Langfuse credentials are in Settings().secrets (settings.py)
     langfuse_prompt_name: str = Field(
@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
             v = "https://openrouter.ai/api/v1"
         return v

-    @field_validator("internal_api_key", mode="before")
-    @classmethod
-    def get_internal_api_key(cls, v):
-        """Get internal API key from environment if not provided."""
-        if v is None:
-            v = os.getenv("CHAT_INTERNAL_API_KEY")
-        return v
-
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",
@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):

     type: ResponseType = ResponseType.START
     messageId: str = Field(..., description="Unique message ID")
-    taskId: str | None = Field(
-        default=None,
-        description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
-    )


 class StreamFinish(StreamBaseResponse):
@@ -1,26 +1,30 @@
 """Chat API routes for chat session management and streaming via SSE."""

 import logging
-import uuid as uuid_module
 from collections.abc import AsyncGenerator
 from typing import Annotated

 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
+from fastapi import APIRouter, Depends, Query, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel

 from backend.util.exceptions import NotFoundError

 from . import service as chat_service
-from . import stream_registry
-from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat, StreamStart

 config = ChatConfig()

+# SSE response headers for streaming
+SSE_RESPONSE_HEADERS = {
+    "Cache-Control": "no-cache",
+    "Connection": "keep-alive",
+    "X-Accel-Buffering": "no",
+    "x-vercel-ai-ui-message-stream": "v1",
+}
+

 logger = logging.getLogger(__name__)

@@ -36,6 +40,60 @@ async def _validate_and_get_session(
     return session


+async def _create_stream_generator(
+    session_id: str,
+    message: str,
+    user_id: str | None,
+    session: ChatSession,
+    is_user_message: bool = True,
+    context: dict[str, str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Create SSE event generator for chat streaming.
+
+    Args:
+        session_id: Chat session ID
+        message: User message to process
+        user_id: Optional authenticated user ID
+        session: Pre-fetched chat session
+        is_user_message: Whether the message is from a user
+        context: Optional context dict with url and content
+
+    Yields:
+        SSE-formatted chunks from the chat completion stream
+    """
+    chunk_count = 0
+    first_chunk_type: str | None = None
+    async for chunk in chat_service.stream_chat_completion(
+        session_id,
+        message,
+        is_user_message=is_user_message,
+        user_id=user_id,
+        session=session,
+        context=context,
+    ):
+        if chunk_count < 3:
+            logger.info(
+                "Chat stream chunk",
+                extra={
+                    "session_id": session_id,
+                    "chunk_type": str(chunk.type),
+                },
+            )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+        chunk_count += 1
+        yield chunk.to_sse()
+    logger.info(
+        "Chat stream completed",
+        extra={
+            "session_id": session_id,
+            "chunk_count": chunk_count,
+            "first_chunk_type": first_chunk_type,
+        },
+    )
+    yield "data: [DONE]\n\n"
+
+
 router = APIRouter(
     tags=["chat"],
 )
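The shared generator above emits Server-Sent Events and terminates with the AI SDK marker `data: [DONE]`. A minimal client-side sketch of consuming that wire format follows; it is not from the repository, and the endpoint path and host are assumptions.

```python
# Hedged client sketch: read the SSE stream until the "[DONE]" terminator.
import asyncio

import aiohttp


async def read_chat_stream(base_url: str, session_id: str, message: str) -> None:
    params = {"message": message}
    # The path below is a placeholder for the GET streaming route shown in this diff.
    url = f"{base_url}/sessions/{session_id}/stream"
    async with aiohttp.ClientSession() as http:
        async with http.get(url, params=params) as resp:
            async for raw in resp.content:
                line = raw.decode().strip()
                if not line.startswith("data: "):
                    continue  # skip blank separators and non-data lines
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break  # AI SDK protocol termination marker
                print("chunk:", payload)  # each payload is one serialized stream chunk


if __name__ == "__main__":
    asyncio.run(read_chat_stream("http://localhost:8000/api/chat", "session-123", "Hello"))
```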
@@ -59,15 +117,6 @@ class CreateSessionResponse(BaseModel):
     user_id: str | None


-class ActiveStreamInfo(BaseModel):
-    """Information about an active stream for reconnection."""
-
-    task_id: str
-    last_message_id: str  # Redis Stream message ID for resumption
-    operation_id: str  # Operation ID for completion tracking
-    tool_name: str  # Name of the tool being executed
-
-
 class SessionDetailResponse(BaseModel):
     """Response model providing complete details for a chat session, including messages."""

@@ -76,7 +125,6 @@ class SessionDetailResponse(BaseModel):
     updated_at: str
     user_id: str | None
     messages: list[dict]
-    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active


 class SessionSummaryResponse(BaseModel):
@@ -95,14 +143,6 @@ class ListSessionsResponse(BaseModel):
     total: int


-class OperationCompleteRequest(BaseModel):
-    """Request model for external completion webhook."""
-
-    success: bool
-    result: dict | str | None = None
-    error: str | None = None
-
-
 # ========== Routes ==========

@@ -188,14 +228,13 @@ async def get_session(
     Retrieve the details of a specific chat session.

     Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
-    If there's an active stream for this session, returns the task_id for reconnection.

     Args:
         session_id: The unique identifier for the desired chat session.
         user_id: The optional authenticated user ID, or None for anonymous access.

     Returns:
-        SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
+        SessionDetailResponse: Details for the requested session, or None if not found.

     """
     session = await get_chat_session(session_id, user_id)
@@ -203,28 +242,11 @@ async def get_session(
         raise NotFoundError(f"Session {session_id} not found.")

     messages = [message.model_dump() for message in session.messages]
-    # Check if there's an active stream for this session
-    active_stream_info = None
-    active_task, last_message_id = await stream_registry.get_active_task_for_session(
-        session_id, user_id
+    logger.info(
+        f"Returning session {session_id}: "
+        f"message_count={len(messages)}, "
+        f"roles={[m.get('role') for m in messages]}"
     )
-    if active_task:
-        # Filter out the in-progress assistant message from the session response.
-        # The client will receive the complete assistant response through the SSE
-        # stream replay instead, preventing duplicate content.
-        if messages and messages[-1].get("role") == "assistant":
-            messages = messages[:-1]
-
-        # Use "0-0" as last_message_id to replay the stream from the beginning.
-        # Since we filtered out the cached assistant message, the client needs
-        # the full stream to reconstruct the response.
-        active_stream_info = ActiveStreamInfo(
-            task_id=active_task.task_id,
-            last_message_id="0-0",
-            operation_id=active_task.operation_id,
-            tool_name=active_task.tool_name,
-        )

     return SessionDetailResponse(
         id=session.session_id,
@@ -232,7 +254,6 @@ async def get_session(
         updated_at=session.updated_at.isoformat(),
         user_id=session.user_id or None,
         messages=messages,
-        active_stream=active_stream_info,
     )

@@ -252,122 +273,27 @@ async def stream_chat_post(
     - Tool call UI elements (if invoked)
     - Tool execution results

-    The AI generation runs in a background task that continues even if the client disconnects.
-    All chunks are written to Redis for reconnection support. If the client disconnects,
-    they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
-
     Args:
         session_id: The chat session identifier to associate with the streamed messages.
         request: Request body containing message, is_user_message, and optional context.
         user_id: Optional authenticated user ID.
     Returns:
-        StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
-        containing the task_id for reconnection.
+        StreamingResponse: SSE-formatted response chunks.

     """
-    import asyncio
-
     session = await _validate_and_get_session(session_id, user_id)

-    # Create a task in the stream registry for reconnection support
-    task_id = str(uuid_module.uuid4())
-    operation_id = str(uuid_module.uuid4())
-    await stream_registry.create_task(
-        task_id=task_id,
-        session_id=session_id,
-        user_id=user_id,
-        tool_call_id="chat_stream",  # Not a tool call, but needed for the model
-        tool_name="chat",
-        operation_id=operation_id,
-    )
-
-    # Background task that runs the AI generation independently of SSE connection
-    async def run_ai_generation():
-        try:
-            # Emit a start event with task_id for reconnection
-            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
-            await stream_registry.publish_chunk(task_id, start_chunk)
-
-            async for chunk in chat_service.stream_chat_completion(
-                session_id,
-                request.message,
-                is_user_message=request.is_user_message,
-                user_id=user_id,
-                session=session,  # Pass pre-fetched session to avoid double-fetch
-                context=request.context,
-            ):
-                # Write to Redis (subscribers will receive via XREAD)
-                await stream_registry.publish_chunk(task_id, chunk)
-
-            # Mark task as completed
-            await stream_registry.mark_task_completed(task_id, "completed")
-        except Exception as e:
-            logger.error(
-                f"Error in background AI generation for session {session_id}: {e}"
-            )
-            await stream_registry.mark_task_completed(task_id, "failed")
-
-    # Start the AI generation in a background task
-    bg_task = asyncio.create_task(run_ai_generation())
-    await stream_registry.set_task_asyncio_task(task_id, bg_task)
-
-    # SSE endpoint that subscribes to the task's stream
-    async def event_generator() -> AsyncGenerator[str, None]:
-        subscriber_queue = None
-        try:
-            # Subscribe to the task stream (this replays existing messages + live updates)
-            subscriber_queue = await stream_registry.subscribe_to_task(
-                task_id=task_id,
-                user_id=user_id,
-                last_message_id="0-0",  # Get all messages from the beginning
-            )
-
-            if subscriber_queue is None:
-                yield StreamFinish().to_sse()
-                yield "data: [DONE]\n\n"
-                return
-
-            # Read from the subscriber queue and yield to SSE
-            while True:
-                try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
-                    yield chunk.to_sse()
-
-                    # Check for finish signal
-                    if isinstance(chunk, StreamFinish):
-                        break
-                except asyncio.TimeoutError:
-                    # Send heartbeat to keep connection alive
-                    yield StreamHeartbeat().to_sse()
-
-        except GeneratorExit:
-            pass  # Client disconnected - background task continues
-        except Exception as e:
-            logger.error(f"Error in SSE stream for task {task_id}: {e}")
-        finally:
-            # Unsubscribe when client disconnects or stream ends to prevent resource leak
-            if subscriber_queue is not None:
-                try:
-                    await stream_registry.unsubscribe_from_task(
-                        task_id, subscriber_queue
-                    )
-                except Exception as unsub_err:
-                    logger.error(
-                        f"Error unsubscribing from task {task_id}: {unsub_err}",
-                        exc_info=True,
-                    )
-            # AI SDK protocol termination - always yield even if unsubscribe fails
-            yield "data: [DONE]\n\n"

     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=request.message,
+            user_id=user_id,
+            session=session,
+            is_user_message=request.is_user_message,
+            context=request.context,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
     )

@@ -399,48 +325,16 @@ async def stream_chat_get(
     """
     session = await _validate_and_get_session(session_id, user_id)
 
-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            message,
-            is_user_message=is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=message,
+            user_id=user_id,
+            session=session,
+            is_user_message=is_user_message,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
     )

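# For reference: a minimal, hedged sketch of a client consuming this SSE endpoint
# with httpx. The URL prefix and query parameter names are illustrative assumptions;
# only the "data: ..." framing and the "data: [DONE]" terminator come from the
# handlers above.
import httpx


async def read_chat_stream(base_url: str, session_id: str, message: str) -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET",
            f"{base_url}/chat/{session_id}/stream",  # hypothetical path
            params={"message": message},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break  # AI SDK protocol termination
                print(payload)  # one JSON-encoded stream chunk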
@@ -470,251 +364,6 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# ========== Task Streaming (SSE Reconnection) ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/tasks/{task_id}/stream",
|
|
||||||
)
|
|
||||||
async def stream_task(
|
|
||||||
task_id: str,
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
|
||||||
last_message_id: str = Query(
|
|
||||||
default="0-0",
|
|
||||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
|
||||||
),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Reconnect to a long-running task's SSE stream.
|
|
||||||
|
|
||||||
When a long-running operation (like agent generation) starts, the client
|
|
||||||
receives a task_id. If the connection drops, the client can reconnect
|
|
||||||
using this endpoint to resume receiving updates.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: The task ID from the operation_started response.
|
|
||||||
user_id: Authenticated user ID for ownership validation.
|
|
||||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
|
||||||
"""
|
|
||||||
# Check task existence and expiry before subscribing
|
|
||||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
|
||||||
|
|
||||||
if error_code == "TASK_EXPIRED":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=410,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_EXPIRED",
|
|
||||||
"message": "This operation has expired. Please try again.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if error_code == "TASK_NOT_FOUND":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_NOT_FOUND",
|
|
||||||
"message": f"Task {task_id} not found.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate ownership if task has an owner
|
|
||||||
if task and task.user_id and user_id != task.user_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail={
|
|
||||||
"code": "ACCESS_DENIED",
|
|
||||||
"message": "You do not have access to this task.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get subscriber queue from stream registry
|
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
|
||||||
task_id=task_id,
|
|
||||||
user_id=user_id,
|
|
||||||
last_message_id=last_message_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if subscriber_queue is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_NOT_FOUND",
|
|
||||||
"message": f"Task {task_id} not found or access denied.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Wait for next chunk with timeout for heartbeats
|
|
||||||
chunk = await asyncio.wait_for(
|
|
||||||
subscriber_queue.get(), timeout=heartbeat_interval
|
|
||||||
)
|
|
||||||
yield chunk.to_sse()
|
|
||||||
|
|
||||||
# Check for finish signal
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
break
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Send heartbeat to keep connection alive
|
|
||||||
yield StreamHeartbeat().to_sse()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
|
||||||
finally:
|
|
||||||
# Unsubscribe when client disconnects or stream ends
|
|
||||||
try:
|
|
||||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
|
||||||
except Exception as unsub_err:
|
|
||||||
logger.error(
|
|
||||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
event_generator(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
"x-vercel-ai-ui-message-stream": "v1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
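# For reference: a hedged sketch of how a client could resume the task stream served
# by the endpoint above after a dropped connection. The /tasks/{task_id}/stream path
# and the last_message_id query parameter come from the route; the router prefix and
# the error handling are illustrative assumptions.
import httpx


async def resume_task_stream(base_url: str, task_id: str, last_message_id: str = "0-0"):
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET",
            f"{base_url}/tasks/{task_id}/stream",
            params={"last_message_id": last_message_id},
        ) as response:
            if response.status_code == 410:
                raise RuntimeError("Task expired - restart the operation")
            async for line in response.aiter_lines():
                if line.startswith("data: ") and line != "data: [DONE]":
                    yield line[len("data: "):]  # replayed, then live, chunks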
@router.get(
|
|
||||||
"/tasks/{task_id}",
|
|
||||||
)
|
|
||||||
async def get_task_status(
|
|
||||||
task_id: str,
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Get the status of a long-running task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: The task ID to check.
|
|
||||||
user_id: Authenticated user ID for ownership validation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If task_id is not found or user doesn't have access.
|
|
||||||
"""
|
|
||||||
task = await stream_registry.get_task(task_id)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
raise NotFoundError(f"Task {task_id} not found.")
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
|
||||||
if task.user_id and user_id != task.user_id:
|
|
||||||
raise NotFoundError(f"Task {task_id} not found.")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task.task_id,
|
|
||||||
"session_id": task.session_id,
|
|
||||||
"status": task.status,
|
|
||||||
"tool_name": task.tool_name,
|
|
||||||
"operation_id": task.operation_id,
|
|
||||||
"created_at": task.created_at.isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== External Completion Webhook ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/operations/{operation_id}/complete",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def complete_operation(
|
|
||||||
operation_id: str,
|
|
||||||
request: OperationCompleteRequest,
|
|
||||||
x_api_key: str | None = Header(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
External completion webhook for long-running operations.
|
|
||||||
|
|
||||||
Called by Agent Generator (or other services) when an operation completes.
|
|
||||||
This triggers the stream registry to publish completion and continue LLM generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: The operation ID to complete.
|
|
||||||
request: Completion payload with success status and result/error.
|
|
||||||
x_api_key: Internal API key for authentication.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Status of the completion.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If API key is invalid or operation not found.
|
|
||||||
"""
|
|
||||||
# Validate internal API key - reject if not configured or invalid
|
|
||||||
if not config.internal_api_key:
|
|
||||||
logger.error(
|
|
||||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail="Webhook not available: internal API key not configured",
|
|
||||||
)
|
|
||||||
if x_api_key != config.internal_api_key:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
|
||||||
|
|
||||||
# Find task by operation_id
|
|
||||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
|
||||||
if task is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"Operation {operation_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Received completion webhook for operation {operation_id} "
|
|
||||||
f"(task_id={task.task_id}, success={request.success})"
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.success:
|
|
||||||
await process_operation_success(task, request.result)
|
|
||||||
else:
|
|
||||||
await process_operation_failure(task, request.error)
|
|
||||||
|
|
||||||
return {"status": "ok", "task_id": task.task_id}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Configuration ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/ttl", status_code=200)
|
|
||||||
async def get_ttl_config() -> dict:
|
|
||||||
"""
|
|
||||||
Get the stream TTL configuration.
|
|
||||||
|
|
||||||
Returns the Time-To-Live settings for chat streams, which determines
|
|
||||||
how long clients can reconnect to an active stream.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: TTL configuration with seconds and milliseconds values.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"stream_ttl_seconds": config.stream_ttl,
|
|
||||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Health Check ==========
|
# ========== Health Check ==========
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,10 +33,9 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.settings import AppEnvironment, Settings
+from backend.util.settings import Settings
 
 from . import db as chat_db
-from . import stream_registry
 from .config import ChatConfig
 from .model import (
     ChatMessage,
@@ -222,18 +221,8 @@ async def _get_system_prompt_template(context: str) -> str:
     try:
         # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
         # Use asyncio.to_thread to avoid blocking the event loop
-        # In non-production environments, fetch the latest prompt version
-        # instead of the production-labeled version for easier testing
-        label = (
-            None
-            if settings.config.app_env == AppEnvironment.PRODUCTION
-            else "latest"
-        )
         prompt = await asyncio.to_thread(
-            langfuse.get_prompt,
-            config.langfuse_prompt_name,
-            label=label,
-            cache_ttl_seconds=0,
+            langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
         )
         return prompt.compile(users_information=context)
     except Exception as e:
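# For reference: a minimal sketch of the environment-dependent label selection removed
# in this hunk, shown outside the asyncio.to_thread wrapper for clarity. The client and
# prompt name are placeholders; get_prompt(name, label=..., cache_ttl_seconds=0) is the
# same SDK call used above.
def fetch_prompt(langfuse_client, name: str, is_production: bool):
    # None selects the production-labeled version; "latest" fetches the newest draft.
    label = None if is_production else "latest"
    return langfuse_client.get_prompt(name, label=label, cache_ttl_seconds=0)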
@@ -628,9 +617,6 @@ async def stream_chat_completion(
                         total_tokens=chunk.totalTokens,
                     )
                 )
-            elif isinstance(chunk, StreamHeartbeat):
-                # Pass through heartbeat to keep SSE connection alive
-                yield chunk
             else:
                 logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)

@@ -1198,9 +1184,8 @@ async def _yield_tool_call(
         )
         return
 
-    # Generate operation ID and task ID
+    # Generate operation ID
     operation_id = str(uuid_module.uuid4())
-    task_id = str(uuid_module.uuid4())
 
     # Build a user-friendly message based on tool and arguments
     if tool_name == "create_agent":
@@ -1243,16 +1228,6 @@ async def _yield_tool_call(
 
     # Wrap session save and task creation in try-except to release lock on failure
     try:
-        # Create task in stream registry for SSE reconnection support
-        await stream_registry.create_task(
-            task_id=task_id,
-            session_id=session.session_id,
-            user_id=session.user_id,
-            tool_call_id=tool_call_id,
-            tool_name=tool_name,
-            operation_id=operation_id,
-        )
-
         # Save assistant message with tool_call FIRST (required by LLM)
         assistant_message = ChatMessage(
             role="assistant",
@@ -1274,27 +1249,23 @@ async def _yield_tool_call(
         session.messages.append(pending_message)
         await upsert_chat_session(session)
         logger.info(
-            f"Saved pending operation {operation_id} (task_id={task_id}) "
-            f"for tool {tool_name} in session {session.session_id}"
+            f"Saved pending operation {operation_id} for tool {tool_name} "
+            f"in session {session.session_id}"
         )
 
         # Store task reference in module-level set to prevent GC before completion
-        bg_task = asyncio.create_task(
-            _execute_long_running_tool_with_streaming(
+        task = asyncio.create_task(
+            _execute_long_running_tool(
                 tool_name=tool_name,
                 parameters=arguments,
                 tool_call_id=tool_call_id,
                 operation_id=operation_id,
-                task_id=task_id,
                 session_id=session.session_id,
                 user_id=session.user_id,
             )
         )
-        _background_tasks.add(bg_task)
-        bg_task.add_done_callback(_background_tasks.discard)
-
-        # Associate the asyncio task with the stream registry task
-        await stream_registry.set_task_asyncio_task(task_id, bg_task)
+        _background_tasks.add(task)
+        task.add_done_callback(_background_tasks.discard)
     except Exception as e:
         # Roll back appended messages to prevent data corruption on subsequent saves
         if (
@@ -1312,11 +1283,6 @@ async def _yield_tool_call(
 
         # Release the Redis lock since the background task won't be spawned
         await _mark_operation_completed(tool_call_id)
-        # Mark stream registry task as failed if it was created
-        try:
-            await stream_registry.mark_task_completed(task_id, status="failed")
-        except Exception:
-            pass
         logger.error(
             f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
         )
@@ -1330,7 +1296,6 @@ async def _yield_tool_call(
             message=started_msg,
             operation_id=operation_id,
             tool_name=tool_name,
-            task_id=task_id,  # Include task_id for SSE reconnection
         ).model_dump_json(),
         success=True,
     )
@@ -1400,9 +1365,6 @@ async def _execute_long_running_tool(
 
     This function runs independently of the SSE connection, so the operation
     survives if the user closes their browser tab.
-
-    NOTE: This is the legacy function without stream registry support.
-    Use _execute_long_running_tool_with_streaming for new implementations.
     """
     try:
         # Load fresh session (not stale reference)
@@ -1455,133 +1417,6 @@ async def _execute_long_running_tool(
|
|||||||
await _mark_operation_completed(tool_call_id)
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
async def _execute_long_running_tool_with_streaming(
|
|
||||||
tool_name: str,
|
|
||||||
parameters: dict[str, Any],
|
|
||||||
tool_call_id: str,
|
|
||||||
operation_id: str,
|
|
||||||
task_id: str,
|
|
||||||
session_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Execute a long-running tool with stream registry support for SSE reconnection.
|
|
||||||
|
|
||||||
This function runs independently of the SSE connection, publishes progress
|
|
||||||
to the stream registry, and survives if the user closes their browser tab.
|
|
||||||
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
|
|
||||||
|
|
||||||
If the external service returns a 202 Accepted (async), this function exits
|
|
||||||
early and lets the Redis Streams completion consumer handle the rest.
|
|
||||||
"""
|
|
||||||
# Track whether we delegated to async processing - if so, the Redis Streams
|
|
||||||
# completion consumer (stream_registry / completion_consumer) will handle cleanup, not us
|
|
||||||
delegated_to_async = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load fresh session (not stale reference)
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
|
||||||
if not session:
|
|
||||||
logger.error(f"Session {session_id} not found for background tool")
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Pass operation_id and task_id to the tool for async processing
|
|
||||||
enriched_parameters = {
|
|
||||||
**parameters,
|
|
||||||
"_operation_id": operation_id,
|
|
||||||
"_task_id": task_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Execute the actual tool
|
|
||||||
result = await execute_tool(
|
|
||||||
tool_name=tool_name,
|
|
||||||
parameters=enriched_parameters,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
user_id=user_id,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the tool result indicates async processing
|
|
||||||
# (e.g., Agent Generator returned 202 Accepted)
|
|
||||||
try:
|
|
||||||
if isinstance(result.output, dict):
|
|
||||||
result_data = result.output
|
|
||||||
elif result.output:
|
|
||||||
result_data = orjson.loads(result.output)
|
|
||||||
else:
|
|
||||||
result_data = {}
|
|
||||||
if result_data.get("status") == "accepted":
|
|
||||||
logger.info(
|
|
||||||
f"Tool {tool_name} delegated to async processing "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id}). "
|
|
||||||
f"Redis Streams completion consumer will handle the rest."
|
|
||||||
)
|
|
||||||
# Don't publish result, don't continue with LLM, and don't cleanup
|
|
||||||
# The Redis Streams consumer (completion_consumer) will handle
|
|
||||||
# everything when the external service completes via webhook
|
|
||||||
delegated_to_async = True
|
|
||||||
return
|
|
||||||
except (orjson.JSONDecodeError, TypeError):
|
|
||||||
pass # Not JSON or not async - continue normally
|
|
||||||
|
|
||||||
# Publish tool result to stream registry
|
|
||||||
await stream_registry.publish_chunk(task_id, result)
|
|
||||||
|
|
||||||
# Update the pending message with result
|
|
||||||
result_str = (
|
|
||||||
result.output
|
|
||||||
if isinstance(result.output, str)
|
|
||||||
else orjson.dumps(result.output).decode("utf-8")
|
|
||||||
)
|
|
||||||
await _update_pending_operation(
|
|
||||||
session_id=session_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
result=result_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Background tool {tool_name} completed for session {session_id} "
|
|
||||||
f"(task_id={task_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate LLM continuation and stream chunks to registry
|
|
||||||
await _generate_llm_continuation_with_streaming(
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as completed in stream registry
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="completed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
|
||||||
error_response = ErrorResponse(
|
|
||||||
message=f"Tool {tool_name} failed: {str(e)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Publish error to stream registry followed by finish event
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task_id,
|
|
||||||
StreamError(errorText=str(e)),
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
|
||||||
|
|
||||||
await _update_pending_operation(
|
|
||||||
session_id=session_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
result=error_response.model_dump_json(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as failed in stream registry
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
|
||||||
finally:
|
|
||||||
# Only cleanup if we didn't delegate to async processing
|
|
||||||
# For async path, the Redis Streams completion consumer handles cleanup
|
|
||||||
if not delegated_to_async:
|
|
||||||
await _mark_operation_completed(tool_call_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_pending_operation(
|
async def _update_pending_operation(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
@@ -1762,128 +1597,3 @@ async def _generate_llm_continuation(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def _generate_llm_continuation_with_streaming(
|
|
||||||
session_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
task_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Generate an LLM response with streaming to the stream registry.
|
|
||||||
|
|
||||||
This is called by background tasks to continue the conversation
|
|
||||||
after a tool result is saved. Chunks are published to the stream registry
|
|
||||||
so reconnecting clients can receive them.
|
|
||||||
"""
|
|
||||||
import uuid as uuid_module
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
|
||||||
await invalidate_session_cache(session_id)
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
|
||||||
if not session:
|
|
||||||
logger.error(f"Session {session_id} not found for LLM continuation")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build system prompt
|
|
||||||
system_prompt, _ = await _build_system_prompt(user_id)
|
|
||||||
|
|
||||||
# Build messages in OpenAI format
|
|
||||||
messages = session.to_openai_messages()
|
|
||||||
if system_prompt:
|
|
||||||
from openai.types.chat import ChatCompletionSystemMessageParam
|
|
||||||
|
|
||||||
system_message = ChatCompletionSystemMessageParam(
|
|
||||||
role="system",
|
|
||||||
content=system_prompt,
|
|
||||||
)
|
|
||||||
messages = [system_message] + messages
|
|
||||||
|
|
||||||
# Build extra_body for tracing
|
|
||||||
extra_body: dict[str, Any] = {
|
|
||||||
"posthogProperties": {
|
|
||||||
"environment": settings.config.app_env.value,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if user_id:
|
|
||||||
extra_body["user"] = user_id[:128]
|
|
||||||
extra_body["posthogDistinctId"] = user_id
|
|
||||||
if session_id:
|
|
||||||
extra_body["session_id"] = session_id[:128]
|
|
||||||
|
|
||||||
# Make streaming LLM call (no tools - just text response)
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
|
||||||
|
|
||||||
# Generate unique IDs for AI SDK protocol
|
|
||||||
message_id = str(uuid_module.uuid4())
|
|
||||||
text_block_id = str(uuid_module.uuid4())
|
|
||||||
|
|
||||||
# Publish start event
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
|
||||||
|
|
||||||
# Stream the response
|
|
||||||
stream = await client.chat.completions.create(
|
|
||||||
model=config.model,
|
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
|
||||||
extra_body=extra_body,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assistant_content = ""
|
|
||||||
async for chunk in stream:
|
|
||||||
if chunk.choices and chunk.choices[0].delta.content:
|
|
||||||
delta = chunk.choices[0].delta.content
|
|
||||||
assistant_content += delta
|
|
||||||
# Publish delta to stream registry
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task_id,
|
|
||||||
StreamTextDelta(id=text_block_id, delta=delta),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Publish end events
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
|
||||||
|
|
||||||
if assistant_content:
|
|
||||||
# Reload session from DB to avoid race condition with user messages
|
|
||||||
fresh_session = await get_chat_session(session_id, user_id)
|
|
||||||
if not fresh_session:
|
|
||||||
logger.error(
|
|
||||||
f"Session {session_id} disappeared during LLM continuation"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Save assistant message to database
|
|
||||||
assistant_message = ChatMessage(
|
|
||||||
role="assistant",
|
|
||||||
content=assistant_content,
|
|
||||||
)
|
|
||||||
fresh_session.messages.append(assistant_message)
|
|
||||||
|
|
||||||
# Save to database (not cache) to persist the response
|
|
||||||
await upsert_chat_session(fresh_session)
|
|
||||||
|
|
||||||
# Invalidate cache so next poll/refresh gets fresh data
|
|
||||||
await invalidate_session_cache(session_id)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Generated streaming LLM continuation for session {session_id} "
|
|
||||||
f"(task_id={task_id}), response length: {len(assistant_content)}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Streaming LLM continuation returned empty response for {session_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
# Publish error to stream registry followed by finish event
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task_id,
|
|
||||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
|
||||||
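# For reference: a condensed sketch of the chunk ordering the removed continuation
# helper publishes to the stream registry (start -> text-start -> text-delta* ->
# text-end), using the same constructors it calls above. The deltas list stands in
# for the OpenAI streaming response and is purely illustrative.
async def publish_text_stream(task_id: str, deltas: list[str]) -> None:
    import uuid

    message_id, block_id = str(uuid.uuid4()), str(uuid.uuid4())
    await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
    await stream_registry.publish_chunk(task_id, StreamTextStart(id=block_id))
    for delta in deltas:
        await stream_registry.publish_chunk(task_id, StreamTextDelta(id=block_id, delta=delta))
    await stream_registry.publish_chunk(task_id, StreamTextEnd(id=block_id))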
|
|||||||
@@ -1,704 +0,0 @@
|
|||||||
"""Stream registry for managing reconnectable SSE streams.
|
|
||||||
|
|
||||||
This module provides a registry for tracking active streaming tasks and their
|
|
||||||
messages. It uses Redis for all state management (no in-memory state), making
|
|
||||||
pods stateless and horizontally scalable.
|
|
||||||
|
|
||||||
Architecture:
|
|
||||||
- Redis Stream: Persists all messages for replay and real-time delivery
|
|
||||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
|
||||||
|
|
||||||
Subscribers:
|
|
||||||
1. Replay missed messages from Redis Stream (XREAD)
|
|
||||||
2. Listen for live updates via blocking XREAD
|
|
||||||
3. No in-memory state required on the subscribing pod
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
|
||||||
|
|
||||||
from .config import ChatConfig
|
|
||||||
from .response_model import StreamBaseResponse, StreamError, StreamFinish
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
config = ChatConfig()
|
|
||||||
|
|
||||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
|
||||||
_local_tasks: dict[str, asyncio.Task] = {}
|
|
||||||
|
|
||||||
# Track listener tasks per subscriber queue for cleanup
|
|
||||||
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
|
|
||||||
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
|
|
||||||
|
|
||||||
# Timeout for putting chunks into subscriber queues (seconds)
|
|
||||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
|
||||||
QUEUE_PUT_TIMEOUT = 5.0
|
|
||||||
|
|
||||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
|
||||||
# Returns 1 if status was updated, 0 if already completed/failed
|
|
||||||
COMPLETE_TASK_SCRIPT = """
|
|
||||||
local current = redis.call("HGET", KEYS[1], "status")
|
|
||||||
if current == "running" then
|
|
||||||
redis.call("HSET", KEYS[1], "status", ARGV[1])
|
|
||||||
return 1
|
|
||||||
end
|
|
||||||
return 0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ActiveTask:
|
|
||||||
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
|
||||||
|
|
||||||
task_id: str
|
|
||||||
session_id: str
|
|
||||||
user_id: str | None
|
|
||||||
tool_call_id: str
|
|
||||||
tool_name: str
|
|
||||||
operation_id: str
|
|
||||||
status: Literal["running", "completed", "failed"] = "running"
|
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
asyncio_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_task_meta_key(task_id: str) -> str:
|
|
||||||
"""Get Redis key for task metadata."""
|
|
||||||
return f"{config.task_meta_prefix}{task_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_task_stream_key(task_id: str) -> str:
|
|
||||||
"""Get Redis key for task message stream."""
|
|
||||||
return f"{config.task_stream_prefix}{task_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
|
||||||
"""Get Redis key for operation_id to task_id mapping."""
|
|
||||||
return f"{config.task_op_prefix}{operation_id}"
|
|
||||||
|
|
||||||
|
|
||||||
async def create_task(
|
|
||||||
task_id: str,
|
|
||||||
session_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
tool_call_id: str,
|
|
||||||
tool_name: str,
|
|
||||||
operation_id: str,
|
|
||||||
) -> ActiveTask:
|
|
||||||
"""Create a new streaming task in Redis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Unique identifier for the task
|
|
||||||
session_id: Chat session ID
|
|
||||||
user_id: User ID (may be None for anonymous)
|
|
||||||
tool_call_id: Tool call ID from the LLM
|
|
||||||
tool_name: Name of the tool being executed
|
|
||||||
operation_id: Operation ID for webhook callbacks
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created ActiveTask instance (metadata only)
|
|
||||||
"""
|
|
||||||
task = ActiveTask(
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
operation_id=operation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store metadata in Redis
|
|
||||||
redis = await get_redis_async()
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
|
||||||
|
|
||||||
await redis.hset( # type: ignore[misc]
|
|
||||||
meta_key,
|
|
||||||
mapping={
|
|
||||||
"task_id": task_id,
|
|
||||||
"session_id": session_id,
|
|
||||||
"user_id": user_id or "",
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"status": task.status,
|
|
||||||
"created_at": task.created_at.isoformat(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await redis.expire(meta_key, config.stream_ttl)
|
|
||||||
|
|
||||||
# Create operation_id -> task_id mapping for webhook lookups
|
|
||||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
|
||||||
|
|
||||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
|
||||||
|
|
||||||
return task
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_chunk(
|
|
||||||
task_id: str,
|
|
||||||
chunk: StreamBaseResponse,
|
|
||||||
) -> str:
|
|
||||||
"""Publish a chunk to Redis Stream.
|
|
||||||
|
|
||||||
All delivery is via Redis Streams - no in-memory state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to publish to
|
|
||||||
chunk: The stream response chunk to publish
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The Redis Stream message ID
|
|
||||||
"""
|
|
||||||
chunk_json = chunk.model_dump_json()
|
|
||||||
message_id = "0-0"
|
|
||||||
|
|
||||||
try:
|
|
||||||
redis = await get_redis_async()
|
|
||||||
stream_key = _get_task_stream_key(task_id)
|
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
|
||||||
raw_id = await redis.xadd(
|
|
||||||
stream_key,
|
|
||||||
{"data": chunk_json},
|
|
||||||
maxlen=config.stream_max_length,
|
|
||||||
)
|
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to publish chunk for task {task_id}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return message_id
|
|
||||||
|
|
||||||
|
|
||||||
async def subscribe_to_task(
|
|
||||||
task_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
last_message_id: str = "0-0",
|
|
||||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
|
||||||
"""Subscribe to a task's stream with replay of missed messages.
|
|
||||||
|
|
||||||
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to subscribe to
|
|
||||||
user_id: User ID for ownership validation
|
|
||||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
|
||||||
or user doesn't have access
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
|
||||||
|
|
||||||
if not meta:
|
|
||||||
logger.debug(f"Task {task_id} not found in Redis")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
|
||||||
task_status = meta.get("status", "")
|
|
||||||
task_user_id = meta.get("user_id", "") or None
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
|
||||||
if task_user_id:
|
|
||||||
if user_id != task_user_id:
|
|
||||||
logger.warning(
|
|
||||||
f"User {user_id} denied access to task {task_id} "
|
|
||||||
f"owned by {task_user_id}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
|
||||||
stream_key = _get_task_stream_key(task_id)
|
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
|
||||||
|
|
||||||
replayed_count = 0
|
|
||||||
replay_last_id = last_message_id
|
|
||||||
if messages:
|
|
||||||
for _stream_name, stream_messages in messages:
|
|
||||||
for msg_id, msg_data in stream_messages:
|
|
||||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
|
||||||
if "data" in msg_data:
|
|
||||||
try:
|
|
||||||
chunk_data = orjson.loads(msg_data["data"])
|
|
||||||
chunk = _reconstruct_chunk(chunk_data)
|
|
||||||
if chunk:
|
|
||||||
await subscriber_queue.put(chunk)
|
|
||||||
replayed_count += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
|
||||||
|
|
||||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
|
||||||
if task_status == "running":
|
|
||||||
listener_task = asyncio.create_task(
|
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
|
||||||
)
|
|
||||||
# Track listener task for cleanup on unsubscribe
|
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
|
||||||
else:
|
|
||||||
# Task is completed/failed - add finish marker
|
|
||||||
await subscriber_queue.put(StreamFinish())
|
|
||||||
|
|
||||||
return subscriber_queue
|
|
||||||
|
|
||||||
|
|
||||||
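# For reference: a hedged sketch of a consumer for the queue returned by
# subscribe_to_task() above, matching how the SSE endpoints drain it: wait with a
# timeout so heartbeats can be emitted, stop on StreamFinish, and always unsubscribe.
# The 15-second interval is an illustrative assumption.
async def drain_task_queue(task_id: str, user_id: str | None) -> None:
    queue = await subscribe_to_task(task_id, user_id, last_message_id="0-0")
    if queue is None:
        return  # unknown task or access denied
    try:
        while True:
            try:
                chunk = await asyncio.wait_for(queue.get(), timeout=15.0)
            except asyncio.TimeoutError:
                continue  # a real SSE handler would emit a heartbeat here
            print(chunk.to_sse())
            if isinstance(chunk, StreamFinish):
                break
    finally:
        await unsubscribe_from_task(task_id, queue)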
async def _stream_listener(
|
|
||||||
task_id: str,
|
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
|
||||||
last_replayed_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
|
||||||
|
|
||||||
This approach avoids the duplicate message issue that can occur with pub/sub
|
|
||||||
when messages are published during the gap between replay and subscription.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to listen for
|
|
||||||
subscriber_queue: Queue to deliver messages to
|
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
|
||||||
"""
|
|
||||||
queue_id = id(subscriber_queue)
|
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
|
||||||
last_delivered_id = last_replayed_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
redis = await get_redis_async()
|
|
||||||
stream_key = _get_task_stream_key(task_id)
|
|
||||||
current_id = last_replayed_id
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Block for up to 30 seconds waiting for new messages
|
|
||||||
# This allows periodic checking if task is still running
|
|
||||||
messages = await redis.xread(
|
|
||||||
{stream_key: current_id}, block=30000, count=100
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
# Timeout - check if task is still running
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
|
||||||
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
|
||||||
if status and status != "running":
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
subscriber_queue.put(StreamFinish()),
|
|
||||||
timeout=QUEUE_PUT_TIMEOUT,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
f"Timeout delivering finish event for task {task_id}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
continue
|
|
||||||
|
|
||||||
for _stream_name, stream_messages in messages:
|
|
||||||
for msg_id, msg_data in stream_messages:
|
|
||||||
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
|
||||||
|
|
||||||
if "data" not in msg_data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunk_data = orjson.loads(msg_data["data"])
|
|
||||||
chunk = _reconstruct_chunk(chunk_data)
|
|
||||||
if chunk:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
subscriber_queue.put(chunk),
|
|
||||||
timeout=QUEUE_PUT_TIMEOUT,
|
|
||||||
)
|
|
||||||
# Update last delivered ID on successful delivery
|
|
||||||
last_delivered_id = current_id
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
f"Subscriber queue full for task {task_id}, "
|
|
||||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
|
||||||
)
|
|
||||||
# Send overflow error with recovery info
|
|
||||||
try:
|
|
||||||
overflow_error = StreamError(
|
|
||||||
errorText="Message delivery timeout - some messages may have been missed",
|
|
||||||
code="QUEUE_OVERFLOW",
|
|
||||||
details={
|
|
||||||
"last_delivered_id": last_delivered_id,
|
|
||||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
subscriber_queue.put_nowait(overflow_error)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
# Queue is completely stuck, nothing more we can do
|
|
||||||
logger.error(
|
|
||||||
f"Cannot deliver overflow error for task {task_id}, "
|
|
||||||
"queue completely blocked"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stop listening on finish
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error processing stream message: {e}")
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
|
||||||
raise # Re-raise to propagate cancellation
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
|
||||||
# On error, send finish to unblock subscriber
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
subscriber_queue.put(StreamFinish()),
|
|
||||||
timeout=QUEUE_PUT_TIMEOUT,
|
|
||||||
)
|
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
|
||||||
logger.warning(
|
|
||||||
f"Could not deliver finish event for task {task_id} after error"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
# Clean up listener task mapping on exit
|
|
||||||
_listener_tasks.pop(queue_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
async def mark_task_completed(
|
|
||||||
task_id: str,
|
|
||||||
status: Literal["completed", "failed"] = "completed",
|
|
||||||
) -> bool:
|
|
||||||
"""Mark a task as completed and publish finish event.
|
|
||||||
|
|
||||||
This is idempotent - calling multiple times with the same task_id is safe.
|
|
||||||
Uses atomic compare-and-swap via Lua script to prevent race conditions.
|
|
||||||
Status is updated first (source of truth), then finish event is published (best-effort).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to mark as completed
|
|
||||||
status: Final status ("completed" or "failed")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if task was newly marked completed, False if already completed/failed
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
|
||||||
|
|
||||||
# Atomic compare-and-swap: only update if status is "running"
|
|
||||||
# This prevents race conditions when multiple callers try to complete simultaneously
|
|
||||||
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
|
||||||
|
|
||||||
if result == 0:
|
|
||||||
logger.debug(f"Task {task_id} already completed/failed, skipping")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
|
||||||
try:
|
|
||||||
await publish_chunk(task_id, StreamFinish())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to publish finish event for task {task_id}: {e}. "
|
|
||||||
"Listeners will detect completion via status polling."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up local task reference if exists
|
|
||||||
_local_tasks.pop(task_id, None)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
|
||||||
"""Find a task by its operation ID.
|
|
||||||
|
|
||||||
Used by webhook callbacks to locate the task to update.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: Operation ID to search for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ActiveTask if found, None otherwise
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
|
||||||
task_id = await redis.get(op_key)
|
|
||||||
|
|
||||||
if not task_id:
|
|
||||||
return None
|
|
||||||
|
|
||||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
|
||||||
return await get_task(task_id_str)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_task(task_id: str) -> ActiveTask | None:
|
|
||||||
"""Get a task by its ID from Redis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to look up
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ActiveTask if found, None otherwise
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
|
||||||
|
|
||||||
if not meta:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
|
||||||
return ActiveTask(
|
|
||||||
task_id=meta.get("task_id", ""),
|
|
||||||
session_id=meta.get("session_id", ""),
|
|
||||||
user_id=meta.get("user_id", "") or None,
|
|
||||||
tool_call_id=meta.get("tool_call_id", ""),
|
|
||||||
tool_name=meta.get("tool_name", ""),
|
|
||||||
operation_id=meta.get("operation_id", ""),
|
|
||||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_task_with_expiry_info(
|
|
||||||
task_id: str,
|
|
||||||
) -> tuple[ActiveTask | None, str | None]:
|
|
||||||
"""Get a task by its ID with expiration detection.
|
|
||||||
|
|
||||||
Returns (task, error_code) where error_code is:
|
|
||||||
- None if task found
|
|
||||||
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
|
|
||||||
- "TASK_NOT_FOUND" if neither exists
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID to look up
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (ActiveTask or None, error_code or None)
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
|
||||||
stream_key = _get_task_stream_key(task_id)
|
|
||||||
|
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
|
||||||
|
|
||||||
if not meta:
|
|
||||||
# Check if stream still has data (metadata expired but stream hasn't)
|
|
||||||
stream_len = await redis.xlen(stream_key)
|
|
||||||
if stream_len > 0:
|
|
||||||
return None, "TASK_EXPIRED"
|
|
||||||
return None, "TASK_NOT_FOUND"
|
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
|
||||||
return (
|
|
||||||
ActiveTask(
|
|
||||||
task_id=meta.get("task_id", ""),
|
|
||||||
session_id=meta.get("session_id", ""),
|
|
||||||
user_id=meta.get("user_id", "") or None,
|
|
||||||
tool_call_id=meta.get("tool_call_id", ""),
|
|
||||||
tool_name=meta.get("tool_name", ""),
|
|
||||||
operation_id=meta.get("operation_id", ""),
|
|
||||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_active_task_for_session(
|
|
||||||
session_id: str,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> tuple[ActiveTask | None, str]:
|
|
||||||
"""Get the active (running) task for a session, if any.
|
|
||||||
|
|
||||||
Scans Redis for tasks matching the session_id with status="running".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: Session ID to look up
|
|
||||||
user_id: User ID for ownership validation (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
|
|
||||||
"""
|
|
||||||
|
|
||||||
redis = await get_redis_async()
|
|
||||||
|
|
||||||
# Scan Redis for task metadata keys
|
|
||||||
cursor = 0
|
|
||||||
tasks_checked = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
cursor, keys = await redis.scan(
|
|
||||||
cursor, match=f"{config.task_meta_prefix}*", count=100
|
|
||||||
)
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
tasks_checked += 1
|
|
||||||
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
|
|
||||||
if not meta:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
|
||||||
task_session_id = meta.get("session_id", "")
|
|
||||||
task_status = meta.get("status", "")
|
|
||||||
task_user_id = meta.get("user_id", "") or None
|
|
||||||
task_id = meta.get("task_id", "")
|
|
||||||
|
|
||||||
if task_session_id == session_id and task_status == "running":
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
|
||||||
if task_user_id and user_id != task_user_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get the last message ID from Redis Stream
|
|
||||||
stream_key = _get_task_stream_key(task_id)
|
|
||||||
last_id = "0-0"
|
|
||||||
try:
|
|
||||||
messages = await redis.xrevrange(stream_key, count=1)
|
|
||||||
if messages:
|
|
||||||
msg_id = messages[0][0]
|
|
||||||
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get last message ID: {e}")
|
|
||||||
|
|
||||||
return (
|
|
||||||
ActiveTask(
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=task_session_id,
|
|
||||||
user_id=task_user_id,
|
|
||||||
tool_call_id=meta.get("tool_call_id", ""),
|
|
||||||
tool_name=meta.get("tool_name", ""),
|
|
||||||
operation_id=meta.get("operation_id", ""),
|
|
||||||
status="running",
|
|
||||||
),
|
|
||||||
last_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cursor == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
return None, "0-0"
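# For reference: a hedged sketch of how a session-resume flow might combine
# get_active_task_for_session() above with subscribe_to_task(): if a task is still
# running for the session, re-subscribe from the last persisted message ID. Purely
# illustrative; not an API of this module.
async def resume_if_running(session_id: str, user_id: str | None):
    task, last_id = await get_active_task_for_session(session_id, user_id)
    if task is None:
        return None  # nothing in flight for this session
    return await subscribe_to_task(task.task_id, user_id, last_message_id=last_id)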
|
|
||||||
|
|
||||||
|
|
||||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|
||||||
"""Reconstruct a StreamBaseResponse from JSON data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk_data: Parsed JSON data from Redis
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Reconstructed response object, or None if unknown type
|
|
||||||
"""
|
|
||||||
from .response_model import (
|
|
||||||
ResponseType,
|
|
||||||
StreamError,
|
|
||||||
StreamFinish,
|
|
||||||
StreamHeartbeat,
|
|
||||||
StreamStart,
|
|
||||||
StreamTextDelta,
|
|
||||||
StreamTextEnd,
|
|
||||||
StreamTextStart,
|
|
||||||
StreamToolInputAvailable,
|
|
||||||
StreamToolInputStart,
|
|
||||||
StreamToolOutputAvailable,
|
|
||||||
StreamUsage,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Map response types to their corresponding classes
|
|
||||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
|
||||||
ResponseType.START.value: StreamStart,
|
|
||||||
ResponseType.FINISH.value: StreamFinish,
|
|
||||||
ResponseType.TEXT_START.value: StreamTextStart,
|
|
||||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
|
||||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
|
||||||
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
|
|
||||||
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
|
|
||||||
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
|
|
||||||
ResponseType.ERROR.value: StreamError,
|
|
||||||
ResponseType.USAGE.value: StreamUsage,
|
|
||||||
ResponseType.HEARTBEAT.value: StreamHeartbeat,
|
|
||||||
}
|
|
||||||
|
|
||||||
chunk_type = chunk_data.get("type")
|
|
||||||
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
if chunk_class is None:
|
|
||||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return chunk_class(**chunk_data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
|
||||||
"""Track the asyncio.Task for a task (local reference only).
|
|
||||||
|
|
||||||
This is just for cleanup purposes - the task state is in Redis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID
|
|
||||||
asyncio_task: The asyncio Task to track
|
|
||||||
"""
|
|
||||||
_local_tasks[task_id] = asyncio_task
|
|
||||||
|
|
||||||
|
|
||||||
async def unsubscribe_from_task(
|
|
||||||
task_id: str,
|
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
|
||||||
) -> None:
|
|
||||||
"""Clean up when a subscriber disconnects.
|
|
||||||
|
|
||||||
Cancels the XREAD-based listener task associated with this subscriber queue
|
|
||||||
to prevent resource leaks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: Task ID
|
|
||||||
subscriber_queue: The subscriber's queue used to look up the listener task
|
|
||||||
"""
|
|
||||||
queue_id = id(subscriber_queue)
|
|
||||||
listener_entry = _listener_tasks.pop(queue_id, None)
|
|
||||||
|
|
||||||
if listener_entry is None:
|
|
||||||
logger.debug(
|
|
||||||
f"No listener task found for task {task_id} queue {queue_id} "
|
|
||||||
"(may have already completed)"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
stored_task_id, listener_task = listener_entry
|
|
||||||
|
|
||||||
if stored_task_id != task_id:
|
|
||||||
logger.warning(
|
|
||||||
f"Task ID mismatch in unsubscribe: expected {task_id}, "
|
|
||||||
f"found {stored_task_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if listener_task.done():
|
|
||||||
logger.debug(f"Listener task for task {task_id} already completed")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Cancel the listener task
|
|
||||||
listener_task.cancel()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for the task to be cancelled with a timeout
|
|
||||||
await asyncio.wait_for(listener_task, timeout=5.0)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Expected - the task was successfully cancelled
|
|
||||||
pass
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
f"Timeout waiting for listener task cancellation for task {task_id}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
|
|
||||||
|
|
||||||
logger.debug(f"Successfully unsubscribed from task {task_id}")
|
|
||||||
@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
 from .agent_output import AgentOutputTool
 from .base import BaseTool
 from .create_agent import CreateAgentTool
-from .customize_agent import CustomizeAgentTool
 from .edit_agent import EditAgentTool
 from .find_agent import FindAgentTool
 from .find_block import FindBlockTool
@@ -35,7 +34,6 @@ logger = logging.getLogger(__name__)
 TOOL_REGISTRY: dict[str, BaseTool] = {
     "add_understanding": AddUnderstandingTool(),
     "create_agent": CreateAgentTool(),
-    "customize_agent": CustomizeAgentTool(),
     "edit_agent": EditAgentTool(),
     "find_agent": FindAgentTool(),
     "find_block": FindBlockTool(),
@@ -8,7 +8,6 @@ from .core import (
     DecompositionStep,
     LibraryAgentSummary,
     MarketplaceAgentSummary,
-    customize_template,
     decompose_goal,
     enrich_library_agents_from_steps,
     extract_search_terms_from_steps,
@@ -20,7 +19,6 @@ from .core import (
     get_library_agent_by_graph_id,
     get_library_agent_by_id,
     get_library_agents_for_generation,
-    graph_to_json,
     json_to_graph,
     save_agent_to_library,
     search_marketplace_agents_for_generation,
@@ -38,7 +36,6 @@ __all__ = [
     "LibraryAgentSummary",
     "MarketplaceAgentSummary",
     "check_external_service_health",
-    "customize_template",
     "decompose_goal",
     "enrich_library_agents_from_steps",
     "extract_search_terms_from_steps",
@@ -51,7 +48,6 @@ __all__ = [
     "get_library_agent_by_id",
     "get_library_agents_for_generation",
     "get_user_message_for_error",
-    "graph_to_json",
     "is_external_service_configured",
     "json_to_graph",
     "save_agent_to_library",
@@ -7,11 +7,18 @@ from typing import Any, NotRequired, TypedDict
 
 from backend.api.features.library import db as library_db
 from backend.api.features.store import db as store_db
-from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
+from backend.data.graph import (
+    Graph,
+    Link,
+    Node,
+    create_graph,
+    get_graph,
+    get_graph_all_versions,
+    get_store_listed_graphs,
+)
 from backend.util.exceptions import DatabaseError, NotFoundError
 
 from .service import (
-    customize_template_external,
     decompose_goal_external,
     generate_agent_external,
     generate_agent_patch_external,
@@ -20,6 +27,8 @@ from .service import (
 
 logger = logging.getLogger(__name__)
 
+AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
+
 
 class ExecutionSummary(TypedDict):
     """Summary of a single execution for quality assessment."""
@@ -540,21 +549,15 @@ async def decompose_goal(
 async def generate_agent(
     instructions: DecompositionResult | dict[str, Any],
     library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.
 
     Args:
         instructions: Structured instructions from decompose_goal
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams
-            completion notification)
-        task_id: Task ID for async processing (enables Redis Streams persistence
-            and SSE delivery)
 
     Returns:
-        Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
+        Agent JSON dict, error dict {"type": "error", ...}, or None on error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -562,13 +565,8 @@ async def generate_agent(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent")
     result = await generate_agent_external(
-        dict(instructions), _to_dict_list(library_agents), operation_id, task_id
+        dict(instructions), _to_dict_list(library_agents)
     )
-
-    # Don't modify async response
-    if result and result.get("status") == "accepted":
-        return result
 
     if result:
         if isinstance(result, dict) and result.get("type") == "error":
             return result
@@ -659,6 +657,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
     )
 
 
+def _reassign_node_ids(graph: Graph) -> None:
+    """Reassign all node and link IDs to new UUIDs.
+
+    This is needed when creating a new version to avoid unique constraint violations.
+    """
+    id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
+
+    for node in graph.nodes:
+        node.id = id_map[node.id]
+
+    for link in graph.links:
+        link.id = str(uuid.uuid4())
+        if link.source_id in id_map:
+            link.source_id = id_map[link.source_id]
+        if link.sink_id in id_map:
+            link.sink_id = id_map[link.sink_id]
+
+
+def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
+    """Populate user_id in AgentExecutorBlock nodes.
+
+    The external agent generator creates AgentExecutorBlock nodes with empty user_id.
+    This function fills in the actual user_id so sub-agents run with correct permissions.
+
+    Args:
+        agent_json: Agent JSON dict (modified in place)
+        user_id: User ID to set
+    """
+    for node in agent_json.get("nodes", []):
+        if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
+            input_default = node.get("input_default") or {}
+            if not input_default.get("user_id"):
+                input_default["user_id"] = user_id
+                node["input_default"] = input_default
+                logger.debug(
+                    f"Set user_id for AgentExecutorBlock node {node.get('id')}"
+                )
+
+
 async def save_agent_to_library(
     agent_json: dict[str, Any], user_id: str, is_update: bool = False
 ) -> tuple[Graph, Any]:
@@ -672,21 +709,63 @@ async def save_agent_to_library(
     Returns:
         Tuple of (created Graph, LibraryAgent)
     """
+    # Populate user_id in AgentExecutorBlock nodes before conversion
+    _populate_agent_executor_user_ids(agent_json, user_id)
+
     graph = json_to_graph(agent_json)
 
     if is_update:
-        return await library_db.update_graph_in_library(graph, user_id)
-    return await library_db.create_graph_in_library(graph, user_id)
+        if graph.id:
+            existing_versions = await get_graph_all_versions(graph.id, user_id)
+            if existing_versions:
+                latest_version = max(v.version for v in existing_versions)
+                graph.version = latest_version + 1
+                _reassign_node_ids(graph)
+                logger.info(f"Updating agent {graph.id} to version {graph.version}")
+    else:
+        graph.id = str(uuid.uuid4())
+        graph.version = 1
+        _reassign_node_ids(graph)
+        logger.info(f"Creating new agent with ID {graph.id}")
+
+    created_graph = await create_graph(graph, user_id)
+
+    library_agents = await library_db.create_library_agent(
+        graph=created_graph,
+        user_id=user_id,
+        sensitive_action_safe_mode=True,
+        create_library_agents_for_sub_graphs=False,
+    )
+
+    return created_graph, library_agents[0]
 
 
-def graph_to_json(graph: Graph) -> dict[str, Any]:
-    """Convert a Graph object to JSON format for the agent generator.
+async def get_agent_as_json(
+    agent_id: str, user_id: str | None
+) -> dict[str, Any] | None:
+    """Fetch an agent and convert to JSON format for editing.
 
     Args:
-        graph: Graph object to convert
+        agent_id: Graph ID or library agent ID
+        user_id: User ID
 
     Returns:
-        Agent as JSON dict
+        Agent as JSON dict or None if not found
     """
+    graph = await get_graph(agent_id, version=None, user_id=user_id)
+
+    if not graph and user_id:
+        try:
+            library_agent = await library_db.get_library_agent(agent_id, user_id)
+            graph = await get_graph(
+                library_agent.graph_id, version=None, user_id=user_id
+            )
+        except NotFoundError:
+            pass
+
+    if not graph:
+        return None
+
     nodes = []
     for node in graph.nodes:
         nodes.append(
@@ -723,41 +802,10 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
     }
 
 
-async def get_agent_as_json(
-    agent_id: str, user_id: str | None
-) -> dict[str, Any] | None:
-    """Fetch an agent and convert to JSON format for editing.
-
-    Args:
-        agent_id: Graph ID or library agent ID
-        user_id: User ID
-
-    Returns:
-        Agent as JSON dict or None if not found
-    """
-    graph = await get_graph(agent_id, version=None, user_id=user_id)
-
-    if not graph and user_id:
-        try:
-            library_agent = await library_db.get_library_agent(agent_id, user_id)
-            graph = await get_graph(
-                library_agent.graph_id, version=None, user_id=user_id
-            )
-        except NotFoundError:
-            pass
-
-    if not graph:
-        return None
-
-    return graph_to_json(graph)
-
-
 async def generate_agent_patch(
     update_request: str,
     current_agent: dict[str, Any],
     library_agents: list[AgentSummary] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Update an existing agent using natural language.
 
@@ -770,12 +818,10 @@ async def generate_agent_patch(
         update_request: Natural language description of changes
         current_agent: Current agent JSON
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams callback)
-        task_id: Task ID for async processing (enables Redis Streams callback)
 
     Returns:
         Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
-        {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
+        error dict {"type": "error", ...}, or None on unexpected error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -783,43 +829,5 @@ async def generate_agent_patch(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
     return await generate_agent_patch_external(
-        update_request,
-        current_agent,
-        _to_dict_list(library_agents),
-        operation_id,
-        task_id,
-    )
-
-
-async def customize_template(
-    template_agent: dict[str, Any],
-    modification_request: str,
-    context: str = "",
-) -> dict[str, Any] | None:
-    """Customize a template/marketplace agent using natural language.
-
-    This is used when users want to modify a template or marketplace agent
-    to fit their specific needs before adding it to their library.
-
-    The external Agent Generator service handles:
-    - Understanding the modification request
-    - Applying changes to the template
-    - Fixing and validating the result
-
-    Args:
-        template_agent: The template agent JSON to customize
-        modification_request: Natural language description of customizations
-        context: Additional context (e.g., answers to previous questions)
-
-    Returns:
-        Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
-        error dict {"type": "error", ...}, or None on unexpected error
-
-    Raises:
-        AgentGeneratorNotConfiguredError: If the external service is not configured.
-    """
-    _check_service_configured()
-    logger.info("Calling external Agent Generator service for customize_template")
-    return await customize_template_external(
-        template_agent, modification_request, context
+        update_request, current_agent, _to_dict_list(library_agents)
     )
@@ -212,45 +212,24 @@ async def decompose_goal_external(
 async def generate_agent_external(
     instructions: dict[str, Any],
     library_agents: list[dict[str, Any]] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate an agent from instructions.
 
     Args:
         instructions: Structured instructions from decompose_goal
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams callback)
-        task_id: Task ID for async processing (enables Redis Streams callback)
 
     Returns:
-        Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
+        Agent JSON dict on success, or error dict {"type": "error", ...} on error
     """
     client = _get_client()
 
-    # Build request payload
     payload: dict[str, Any] = {"instructions": instructions}
     if library_agents:
         payload["library_agents"] = library_agents
-    if operation_id and task_id:
-        payload["operation_id"] = operation_id
-        payload["task_id"] = task_id
 
     try:
         response = await client.post("/api/generate-agent", json=payload)
-
-        # Handle 202 Accepted for async processing
-        if response.status_code == 202:
-            logger.info(
-                f"Agent Generator accepted async request "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return {
-                "status": "accepted",
-                "operation_id": operation_id,
-                "task_id": task_id,
-            }
-
         response.raise_for_status()
         data = response.json()
 
@@ -282,8 +261,6 @@ async def generate_agent_patch_external(
     update_request: str,
     current_agent: dict[str, Any],
     library_agents: list[dict[str, Any]] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.
 
@@ -291,40 +268,21 @@ async def generate_agent_patch_external(
         update_request: Natural language description of changes
         current_agent: Current agent JSON
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams callback)
-        task_id: Task ID for async processing (enables Redis Streams callback)
 
     Returns:
-        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
+        Updated agent JSON, clarifying questions dict, or error dict on error
     """
     client = _get_client()
 
-    # Build request payload
     payload: dict[str, Any] = {
         "update_request": update_request,
         "current_agent_json": current_agent,
     }
     if library_agents:
         payload["library_agents"] = library_agents
-    if operation_id and task_id:
-        payload["operation_id"] = operation_id
-        payload["task_id"] = task_id
 
     try:
         response = await client.post("/api/update-agent", json=payload)
-
-        # Handle 202 Accepted for async processing
-        if response.status_code == 202:
-            logger.info(
-                f"Agent Generator accepted async update request "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return {
-                "status": "accepted",
-                "operation_id": operation_id,
-                "task_id": task_id,
-            }
-
         response.raise_for_status()
         data = response.json()
 
@@ -368,77 +326,6 @@ async def generate_agent_patch_external(
         return _create_error_response(error_msg, "unexpected_error")
 
 
-async def customize_template_external(
-    template_agent: dict[str, Any],
-    modification_request: str,
-    context: str = "",
-) -> dict[str, Any] | None:
-    """Call the external service to customize a template/marketplace agent.
-
-    Args:
-        template_agent: The template agent JSON to customize
-        modification_request: Natural language description of customizations
-        context: Additional context (e.g., answers to previous questions)
-
-    Returns:
-        Customized agent JSON, clarifying questions dict, or error dict on error
-    """
-    client = _get_client()
-
-    request = modification_request
-    if context:
-        request = f"{modification_request}\n\nAdditional context from user:\n{context}"
-
-    payload: dict[str, Any] = {
-        "template_agent_json": template_agent,
-        "modification_request": request,
-    }
-
-    try:
-        response = await client.post("/api/template-modification", json=payload)
-        response.raise_for_status()
-        data = response.json()
-
-        if not data.get("success"):
-            error_msg = data.get("error", "Unknown error from Agent Generator")
-            error_type = data.get("error_type", "unknown")
-            logger.error(
-                f"Agent Generator template customization failed: {error_msg} "
-                f"(type: {error_type})"
-            )
-            return _create_error_response(error_msg, error_type)
-
-        # Check if it's clarifying questions
-        if data.get("type") == "clarifying_questions":
-            return {
-                "type": "clarifying_questions",
-                "questions": data.get("questions", []),
-            }
-
-        # Check if it's an error passed through
-        if data.get("type") == "error":
-            return _create_error_response(
-                data.get("error", "Unknown error"),
-                data.get("error_type", "unknown"),
-            )
-
-        # Otherwise return the customized agent JSON
-        return data.get("agent_json")
-
-    except httpx.HTTPStatusError as e:
-        error_type, error_msg = _classify_http_error(e)
-        logger.error(error_msg)
-        return _create_error_response(error_msg, error_type)
-    except httpx.RequestError as e:
-        error_type, error_msg = _classify_request_error(e)
-        logger.error(error_msg)
-        return _create_error_response(error_msg, error_type)
-    except Exception as e:
-        error_msg = f"Unexpected error calling Agent Generator: {e}"
-        logger.error(error_msg)
-        return _create_error_response(error_msg, "unexpected_error")
-
-
 async def get_blocks_external() -> list[dict[str, Any]] | None:
     """Get available blocks from the external service.
 
@@ -206,9 +206,9 @@ async def search_agents(
             ]
         )
         no_results_msg = (
-            f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
+            f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
             if source == "marketplace"
-            else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
+            else f"No agents matching '{query}' found in your library."
         )
         return NoResultsResponse(
             message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -224,10 +224,10 @@ async def search_agents(
     message = (
         "Now you have found some options for the user to choose from. "
         "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
-        "Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
+        "Please ask the user if they would like to use any of these agents."
        if source == "marketplace"
        else "Found agents in the user's library. You can provide a link to view an agent at: "
-        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
+        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
     )
 
     return AgentsFoundResponse(
@@ -18,7 +18,6 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
-    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -99,10 +98,6 @@ class CreateAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
-        # Extract async processing params (passed by long-running tool handler)
-        operation_id = kwargs.get("_operation_id")
-        task_id = kwargs.get("_task_id")
-
         if not description:
             return ErrorResponse(
                 message="Please provide a description of what the agent should do.",
@@ -224,12 +219,7 @@ class CreateAgentTool(BaseTool):
             logger.warning(f"Failed to enrich library agents from steps: {e}")
 
         try:
-            agent_json = await generate_agent(
-                decomposition_result,
-                library_agents,
-                operation_id=operation_id,
-                task_id=task_id,
-            )
+            agent_json = await generate_agent(decomposition_result, library_agents)
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -273,19 +263,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )
 
-        # Check if Agent Generator accepted for async processing
-        if agent_json.get("status") == "accepted":
-            logger.info(
-                f"Agent generation delegated to async processing "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return AsyncProcessingResponse(
-                message="Agent generation started. You'll be notified when it's complete.",
-                operation_id=operation_id,
-                task_id=task_id,
-                session_id=session_id,
-            )
-
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))
@@ -1,337 +0,0 @@
-"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
-
-import logging
-from typing import Any
-
-from backend.api.features.chat.model import ChatSession
-from backend.api.features.store import db as store_db
-from backend.api.features.store.exceptions import AgentNotFoundError
-
-from .agent_generator import (
-    AgentGeneratorNotConfiguredError,
-    customize_template,
-    get_user_message_for_error,
-    graph_to_json,
-    save_agent_to_library,
-)
-from .base import BaseTool
-from .models import (
-    AgentPreviewResponse,
-    AgentSavedResponse,
-    ClarificationNeededResponse,
-    ClarifyingQuestion,
-    ErrorResponse,
-    ToolResponseBase,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class CustomizeAgentTool(BaseTool):
-    """Tool for customizing marketplace/template agents using natural language."""
-
-    @property
-    def name(self) -> str:
-        return "customize_agent"
-
-    @property
-    def description(self) -> str:
-        return (
-            "Customize a marketplace or template agent using natural language. "
-            "Takes an existing agent from the marketplace and modifies it based on "
-            "the user's requirements before adding to their library."
-        )
-
-    @property
-    def requires_auth(self) -> bool:
-        return True
-
-    @property
-    def is_long_running(self) -> bool:
-        return True
-
-    @property
-    def parameters(self) -> dict[str, Any]:
-        return {
-            "type": "object",
-            "properties": {
-                "agent_id": {
-                    "type": "string",
-                    "description": (
-                        "The marketplace agent ID in format 'creator/slug' "
-                        "(e.g., 'autogpt/newsletter-writer'). "
-                        "Get this from find_agent results."
-                    ),
-                },
-                "modifications": {
-                    "type": "string",
-                    "description": (
-                        "Natural language description of how to customize the agent. "
-                        "Be specific about what changes you want to make."
-                    ),
-                },
-                "context": {
-                    "type": "string",
-                    "description": (
-                        "Additional context or answers to previous clarifying questions."
-                    ),
-                },
-                "save": {
-                    "type": "boolean",
-                    "description": (
-                        "Whether to save the customized agent to the user's library. "
-                        "Default is true. Set to false for preview only."
-                    ),
-                    "default": True,
-                },
-            },
-            "required": ["agent_id", "modifications"],
-        }
-
-    async def _execute(
-        self,
-        user_id: str | None,
-        session: ChatSession,
-        **kwargs,
-    ) -> ToolResponseBase:
-        """Execute the customize_agent tool.
-
-        Flow:
-        1. Parse the agent ID to get creator/slug
-        2. Fetch the template agent from the marketplace
-        3. Call customize_template with the modification request
-        4. Preview or save based on the save parameter
-        """
-        agent_id = kwargs.get("agent_id", "").strip()
-        modifications = kwargs.get("modifications", "").strip()
-        context = kwargs.get("context", "")
-        save = kwargs.get("save", True)
-        session_id = session.session_id if session else None
-
-        if not agent_id:
-            return ErrorResponse(
-                message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
-                error="missing_agent_id",
-                session_id=session_id,
-            )
-
-        if not modifications:
-            return ErrorResponse(
-                message="Please describe how you want to customize this agent.",
-                error="missing_modifications",
-                session_id=session_id,
-            )
-
-        # Parse agent_id in format "creator/slug"
-        parts = [p.strip() for p in agent_id.split("/")]
-        if len(parts) != 2 or not parts[0] or not parts[1]:
-            return ErrorResponse(
-                message=(
-                    f"Invalid agent ID format: '{agent_id}'. "
-                    "Expected format is 'creator/agent-name' "
-                    "(e.g., 'autogpt/newsletter-writer')."
-                ),
-                error="invalid_agent_id_format",
-                session_id=session_id,
-            )
-
-        creator_username, agent_slug = parts
-
-        # Fetch the marketplace agent details
-        try:
-            agent_details = await store_db.get_store_agent_details(
-                username=creator_username, agent_name=agent_slug
-            )
-        except AgentNotFoundError:
-            return ErrorResponse(
-                message=(
-                    f"Could not find marketplace agent '{agent_id}'. "
-                    "Please check the agent ID and try again."
-                ),
-                error="agent_not_found",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
-            return ErrorResponse(
-                message="Failed to fetch the marketplace agent. Please try again.",
-                error="fetch_error",
-                session_id=session_id,
-            )
-
-        if not agent_details.store_listing_version_id:
-            return ErrorResponse(
-                message=(
-                    f"The agent '{agent_id}' does not have an available version. "
-                    "Please try a different agent."
-                ),
-                error="no_version_available",
-                session_id=session_id,
-            )
-
-        # Get the full agent graph
-        try:
-            graph = await store_db.get_agent(agent_details.store_listing_version_id)
-            template_agent = graph_to_json(graph)
-        except Exception as e:
-            logger.error(f"Error fetching agent graph for {agent_id}: {e}")
-            return ErrorResponse(
-                message="Failed to fetch the agent configuration. Please try again.",
-                error="graph_fetch_error",
-                session_id=session_id,
-            )
-
-        # Call customize_template
-        try:
-            result = await customize_template(
-                template_agent=template_agent,
-                modification_request=modifications,
-                context=context,
-            )
-        except AgentGeneratorNotConfiguredError:
-            return ErrorResponse(
-                message=(
-                    "Agent customization is not available. "
-                    "The Agent Generator service is not configured."
-                ),
-                error="service_not_configured",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.error(f"Error calling customize_template for {agent_id}: {e}")
-            return ErrorResponse(
-                message=(
-                    "Failed to customize the agent due to a service error. "
-                    "Please try again."
-                ),
-                error="customization_service_error",
-                session_id=session_id,
-            )
-
-        if result is None:
-            return ErrorResponse(
-                message=(
-                    "Failed to customize the agent. "
-                    "The agent generation service may be unavailable or timed out. "
-                    "Please try again."
-                ),
-                error="customization_failed",
-                session_id=session_id,
-            )
-
-        # Handle error response
-        if isinstance(result, dict) and result.get("type") == "error":
-            error_msg = result.get("error", "Unknown error")
-            error_type = result.get("error_type", "unknown")
-            user_message = get_user_message_for_error(
-                error_type,
-                operation="customize the agent",
-                llm_parse_message=(
-                    "The AI had trouble customizing the agent. "
-                    "Please try again or simplify your request."
-                ),
-                validation_message=(
-                    "The customized agent failed validation. "
-                    "Please try rephrasing your request."
-                ),
-                error_details=error_msg,
-            )
-            return ErrorResponse(
-                message=user_message,
-                error=f"customization_failed:{error_type}",
-                session_id=session_id,
-            )
-
-        # Handle clarifying questions
-        if isinstance(result, dict) and result.get("type") == "clarifying_questions":
-            questions = result.get("questions") or []
-            if not isinstance(questions, list):
-                logger.error(
-                    f"Unexpected clarifying questions format: {type(questions)}"
-                )
-                questions = []
-            return ClarificationNeededResponse(
-                message=(
-                    "I need some more information to customize this agent. "
-                    "Please answer the following questions:"
-                ),
-                questions=[
-                    ClarifyingQuestion(
-                        question=q.get("question", ""),
-                        keyword=q.get("keyword", ""),
-                        example=q.get("example"),
-                    )
-                    for q in questions
-                    if isinstance(q, dict)
-                ],
-                session_id=session_id,
-            )
-
-        # Result should be the customized agent JSON
-        if not isinstance(result, dict):
-            logger.error(f"Unexpected customize_template response type: {type(result)}")
-            return ErrorResponse(
-                message="Failed to customize the agent due to an unexpected response.",
-                error="unexpected_response_type",
-                session_id=session_id,
-            )
-
-        customized_agent = result
-
-        agent_name = customized_agent.get(
-            "name", f"Customized {agent_details.agent_name}"
-        )
-        agent_description = customized_agent.get("description", "")
-        nodes = customized_agent.get("nodes")
-        links = customized_agent.get("links")
-        node_count = len(nodes) if isinstance(nodes, list) else 0
-        link_count = len(links) if isinstance(links, list) else 0
-
-        if not save:
-            return AgentPreviewResponse(
-                message=(
-                    f"I've customized the agent '{agent_details.agent_name}'. "
-                    f"The customized agent has {node_count} blocks. "
-                    f"Review it and call customize_agent with save=true to save it."
-                ),
-                agent_json=customized_agent,
-                agent_name=agent_name,
-                description=agent_description,
-                node_count=node_count,
-                link_count=link_count,
-                session_id=session_id,
-            )
-
-        if not user_id:
-            return ErrorResponse(
-                message="You must be logged in to save agents.",
-                error="auth_required",
-                session_id=session_id,
-            )
-
-        # Save to user's library
-        try:
-            created_graph, library_agent = await save_agent_to_library(
-                customized_agent, user_id, is_update=False
-            )
-
-            return AgentSavedResponse(
-                message=(
-                    f"Customized agent '{created_graph.name}' "
-                    f"(based on '{agent_details.agent_name}') "
-                    f"has been saved to your library!"
-                ),
-                agent_id=created_graph.id,
-                agent_name=created_graph.name,
-                library_agent_id=library_agent.id,
-                library_agent_link=f"/library/agents/{library_agent.id}",
-                agent_page_link=f"/build?flowID={created_graph.id}",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.error(f"Error saving customized agent: {e}")
-            return ErrorResponse(
-                message="Failed to save the customized agent. Please try again.",
-                error="save_failed",
-                session_id=session_id,
-            )
@@ -17,7 +17,6 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
-    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -105,10 +104,6 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
-        # Extract async processing params (passed by long-running tool handler)
-        operation_id = kwargs.get("_operation_id")
-        task_id = kwargs.get("_task_id")
-
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",
@@ -154,11 +149,7 @@ class EditAgentTool(BaseTool):
 
         try:
             result = await generate_agent_patch(
-                update_request,
-                current_agent,
-                library_agents,
-                operation_id=operation_id,
-                task_id=task_id,
+                update_request, current_agent, library_agents
             )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
@@ -178,20 +169,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )
 
-        # Check if Agent Generator accepted for async processing
-        if result.get("status") == "accepted":
-            logger.info(
-                f"Agent edit delegated to async processing "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return AsyncProcessingResponse(
-                message="Agent edit started. You'll be notified when it's complete.",
-                operation_id=operation_id,
-                task_id=task_id,
-                session_id=session_id,
-            )
-
-        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")
@@ -0,0 +1,77 @@
+"""Shared helpers for chat tools."""
+
+from typing import Any
+
+from .models import ErrorResponse
+
+
+def error_response(
+    message: str, session_id: str | None, **kwargs: Any
+) -> ErrorResponse:
+    """Create standardized error response.
+
+    Args:
+        message: Error message to display
+        session_id: Current session ID
+        **kwargs: Additional fields to pass to ErrorResponse
+
+    Returns:
+        ErrorResponse with the given message and session_id
+    """
+    return ErrorResponse(message=message, session_id=session_id, **kwargs)
+
+
+def get_inputs_from_schema(
+    input_schema: dict[str, Any],
+    exclude_fields: set[str] | None = None,
+) -> list[dict[str, Any]]:
+    """Extract input field info from JSON schema.
+
+    Args:
+        input_schema: JSON schema dict with 'properties' and 'required'
+        exclude_fields: Set of field names to exclude (e.g., credential fields)
+
+    Returns:
+        List of dicts with field info (name, title, type, description, required, default)
+    """
+    exclude = exclude_fields or set()
+    properties = input_schema.get("properties", {})
+    required = set(input_schema.get("required", []))
+
+    return [
+        {
+            "name": name,
+            "title": schema.get("title", name),
+            "type": schema.get("type", "string"),
+            "description": schema.get("description", ""),
+            "required": name in required,
+            "default": schema.get("default"),
+        }
+        for name, schema in properties.items()
+        if name not in exclude
+    ]
+
+
+def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
+    """Format input fields as a readable markdown list.
+
+    Args:
+        inputs: List of input dicts from get_inputs_from_schema
+
+    Returns:
+        Markdown-formatted string listing the inputs
+    """
+    if not inputs:
+        return "No inputs required."
+
+    lines = []
+    for inp in inputs:
+        required_marker = " (required)" if inp.get("required") else ""
+        default = inp.get("default")
+        default_info = f" [default: {default}]" if default is not None else ""
+        description = inp.get("description", "")
+        desc_info = f" - {description}" if description else ""
+
+        lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
+
+    return "\n".join(lines)
@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
     OPERATION_STARTED = "operation_started"
     OPERATION_PENDING = "operation_pending"
     OPERATION_IN_PROGRESS = "operation_in_progress"
-    # Input validation
-    INPUT_VALIDATION_ERROR = "input_validation_error"
 
 
 # Base response model
@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
     has_external_trigger: bool | None = None
     new_output: bool | None = None
     graph_id: str | None = None
-    inputs: dict[str, Any] | None = Field(
-        default=None,
-        description="Input schema for the agent, including field names, types, and defaults",
-    )
 
 
 class AgentsFoundResponse(ToolResponseBase):
@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
     details: dict[str, Any] | None = None
 
 
-class InputValidationErrorResponse(ToolResponseBase):
-    """Response when run_agent receives unknown input fields."""
-
-    type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
-    unrecognized_fields: list[str] = Field(
-        description="List of input field names that were not recognized"
-    )
-    inputs: dict[str, Any] = Field(
-        description="The agent's valid input schema for reference"
-    )
-    graph_id: str | None = None
-    graph_version: int | None = None
-
-
 # Agent output models
 class ExecutionOutputInfo(BaseModel):
     """Summary of a single execution's outputs."""
@@ -372,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
 
     This is returned immediately to the client while the operation continues
     to execute. The user can close the tab and check back later.
-
-    The task_id can be used to reconnect to the SSE stream via
-    GET /chat/tasks/{task_id}/stream?last_idx=0
     """
 
     type: ResponseType = ResponseType.OPERATION_STARTED
     operation_id: str
     tool_name: str
-    task_id: str | None = None  # For SSE reconnection
 
 
 class OperationPendingResponse(ToolResponseBase):
@@ -404,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
 
     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
-
-
-class AsyncProcessingResponse(ToolResponseBase):
-    """Response when an operation has been delegated to async processing.
-
-    This is returned by tools when the external service accepts the request
-    for async processing (HTTP 202 Accepted). The Redis Streams completion
-    consumer will handle the result when the external service completes.
-
-    The status field is specifically "accepted" to allow the long-running tool
-    handler to detect this response and skip LLM continuation.
-    """
-
-    type: ResponseType = ResponseType.OPERATION_STARTED
-    status: str = "accepted"  # Must be "accepted" for detection
-    operation_id: str | None = None
-    task_id: str | None = None
@@ -24,13 +24,13 @@ from backend.util.timezone_utils import (
 )
 
 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,
     ErrorResponse,
     ExecutionOptions,
     ExecutionStartedResponse,
-    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,
@@ -274,22 +274,6 @@ class RunAgentTool(BaseTool):
         input_properties = graph.input_schema.get("properties", {})
         required_fields = set(graph.input_schema.get("required", []))
         provided_inputs = set(params.inputs.keys())
-        valid_fields = set(input_properties.keys())
-
-        # Check for unknown input fields
-        unrecognized_fields = provided_inputs - valid_fields
-        if unrecognized_fields:
-            return InputValidationErrorResponse(
-                message=(
-                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
-                    f"Agent was not executed. Please use the correct field names from the schema."
-                ),
-                session_id=session_id,
-                unrecognized_fields=sorted(unrecognized_fields),
-                inputs=graph.input_schema,
-                graph_id=graph.id,
-                graph_version=graph.version,
-            )
 
         # If agent has inputs but none were provided AND use_defaults is not set,
         # always show what's available first so user can decide
@@ -371,19 +355,7 @@ class RunAgentTool(BaseTool):
 
     def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
         """Extract inputs list from schema."""
-        inputs_list = []
-        if isinstance(input_schema, dict) and "properties" in input_schema:
-            for field_name, field_schema in input_schema["properties"].items():
-                inputs_list.append(
-                    {
-                        "name": field_name,
-                        "title": field_schema.get("title", field_name),
-                        "type": field_schema.get("type", "string"),
-                        "description": field_schema.get("description", ""),
-                        "required": field_name in input_schema.get("required", []),
-                    }
-                )
-        return inputs_list
+        return get_inputs_from_schema(input_schema)
 
     def _get_execution_modes(self, graph: GraphModel) -> list[str]:
         """Get available execution modes for the graph."""
@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
     # Should return error about missing schedule_name
     assert result_data.get("type") == "error"
     assert "schedule_name" in result_data["message"].lower()
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
-    """Test that run_agent returns input_validation_error for unknown input fields."""
-    user = setup_test_data["user"]
-    store_submission = setup_test_data["store_submission"]
-
-    tool = RunAgentTool()
-    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
-    session = make_session(user_id=user.id)
-
-    # Execute with unknown input field names
-    response = await tool.execute(
-        user_id=user.id,
-        session_id=str(uuid.uuid4()),
-        tool_call_id=str(uuid.uuid4()),
-        username_agent_slug=agent_marketplace_id,
-        inputs={
-            "unknown_field": "some value",
-            "another_unknown": "another value",
-        },
-        session=session,
-    )
-
-    assert response is not None
-    assert hasattr(response, "output")
-    assert isinstance(response.output, str)
-    result_data = orjson.loads(response.output)
-
-    # Should return input_validation_error type with unrecognized fields
-    assert result_data.get("type") == "input_validation_error"
-    assert "unrecognized_fields" in result_data
-    assert set(result_data["unrecognized_fields"]) == {
-        "another_unknown",
-        "unknown_field",
-    }
-    assert "inputs" in result_data  # Contains the valid schema
-    assert "Agent was not executed" in result_data["message"]
@@ -5,17 +5,16 @@ import uuid
from collections import defaultdict
from typing import Any

from pydantic_core import PydanticUndefined

from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError

from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
    BlockOutputResponse,
    ErrorResponse,
@@ -24,7 +23,10 @@ from .models import (
    ToolResponseBase,
    UserReadiness,
)
from .utils import build_missing_credentials_from_field_info
from .utils import (
    build_missing_credentials_from_field_info,
    match_credentials_to_requirements,
)

logger = logging.getLogger(__name__)

@@ -73,90 +75,39 @@ class RunBlockTool(BaseTool):
    def requires_auth(self) -> bool:
        return True

    def _get_credentials_requirements(
        self,
        block: Any,
    ) -> dict[str, CredentialsFieldInfo]:
        """
        Get credential requirements from block's input schema.

        Args:
            block: Block to get credentials for

        Returns:
            Dict mapping field names to CredentialsFieldInfo
        """
        credentials_fields_info = block.input_schema.get_credentials_fields_info()
        return credentials_fields_info if credentials_fields_info else {}

    async def _check_block_credentials(
        self,
        user_id: str,
        block: Any,
        input_data: dict[str, Any] | None = None,
    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
        """
        Check if user has required credentials for a block.

        Args:
            user_id: User ID
            block: Block to check credentials for
            input_data: Input data for the block (used to determine provider via discriminator)

        Returns:
            tuple[matched_credentials, missing_credentials]
        """
        matched_credentials: dict[str, CredentialsMetaInput] = {}
        missing_credentials: list[CredentialsMetaInput] = []
        input_data = input_data or {}

        # Get credential field info from block's input schema
        credentials_fields_info = block.input_schema.get_credentials_fields_info()

        if not credentials_fields_info:
            return matched_credentials, missing_credentials

        # Get user's available credentials
        creds_manager = IntegrationCredentialsManager()
        available_creds = await creds_manager.store.get_all_creds(user_id)

        for field_name, field_info in credentials_fields_info.items():
            effective_field_info = field_info
            if field_info.discriminator and field_info.discriminator_mapping:
                # Get discriminator from input, falling back to schema default
                discriminator_value = input_data.get(field_info.discriminator)
                if discriminator_value is None:
                    field = block.input_schema.model_fields.get(
                        field_info.discriminator
                    )
                    if field and field.default is not PydanticUndefined:
                        discriminator_value = field.default

                if (
                    discriminator_value
                    and discriminator_value in field_info.discriminator_mapping
                ):
                    effective_field_info = field_info.discriminate(discriminator_value)
                    logger.debug(
                        f"Discriminated provider for {field_name}: "
                        f"{discriminator_value} -> {effective_field_info.provider}"
                    )

            matching_cred = next(
                (
                    cred
                    for cred in available_creds
                    if cred.provider in effective_field_info.provider
                    and cred.type in effective_field_info.supported_types
                ),
                None,
            )

            if matching_cred:
                matched_credentials[field_name] = CredentialsMetaInput(
                    id=matching_cred.id,
                    provider=matching_cred.provider,  # type: ignore
                    type=matching_cred.type,
                    title=matching_cred.title,
                )
            else:
                # Create a placeholder for the missing credential
                provider = next(iter(effective_field_info.provider), "unknown")
                cred_type = next(iter(effective_field_info.supported_types), "api_key")
                missing_credentials.append(
                    CredentialsMetaInput(
                        id=field_name,
                        provider=provider,  # type: ignore
                        type=cred_type,  # type: ignore
                        title=field_name.replace("_", " ").title(),
                    )
                )

        return matched_credentials, missing_credentials

        requirements = self._get_credentials_requirements(block)

        if not requirements:
            return {}, []

        return await match_credentials_to_requirements(user_id, requirements)

    async def _execute(
        self,
@@ -214,9 +165,10 @@ class RunBlockTool(BaseTool):

        logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")

        # Check credentials
        creds_manager = IntegrationCredentialsManager()
        matched_credentials, missing_credentials = await self._check_block_credentials(
            user_id, block, input_data
            user_id, block
        )

        if missing_credentials:
@@ -347,27 +299,7 @@ class RunBlockTool(BaseTool):

    def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
        """Extract non-credential inputs from block schema."""
        inputs_list = []
        schema = block.input_schema.jsonschema()
        properties = schema.get("properties", {})
        required_fields = set(schema.get("required", []))

        # Get credential field names to exclude
        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
        for field_name, field_schema in properties.items():
            # Skip credential fields
            if field_name in credentials_fields:
                continue

            inputs_list.append(
                {
                    "name": field_name,
                    "title": field_schema.get("title", field_name),
                    "type": field_schema.get("type", "string"),
                    "description": field_schema.get("description", ""),
                    "required": field_name in required_fields,
                }
            )

        return inputs_list
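The hunk above swaps the inline loop for a shared `get_inputs_from_schema` helper in `.helpers`. The helper's body is not shown in this diff, so the sketch below is an assumption of what such a function could look like, reconstructed from the loop it replaces:

```python
from typing import Any


def get_inputs_from_schema(
    schema: dict[str, Any], exclude_fields: set[str] | None = None
) -> list[dict[str, Any]]:
    """Flatten a JSON schema's properties into a list of input descriptors."""
    exclude_fields = exclude_fields or set()
    required = set(schema.get("required", []))
    inputs: list[dict[str, Any]] = []
    for name, field in schema.get("properties", {}).items():
        if name in exclude_fields:
            continue  # skip credential fields and other excluded inputs
        inputs.append(
            {
                "name": name,
                "title": field.get("title", name),
                "type": field.get("type", "string"),
                "description": field.get("description", ""),
                "required": name in required,
            }
        )
    return inputs
```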
@@ -6,16 +6,10 @@ from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import (
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
    Credentials,
    CredentialsFieldInfo,
    CredentialsMetaInput,
    HostScopedCredentials,
    OAuth2Credentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError

logger = logging.getLogger(__name__)
@@ -45,8 +39,14 @@ async def fetch_graph_from_store_slug(
        return None, None

    # Get the graph from store listing version
    graph = await store_db.get_available_graph(
    graph_meta = await store_db.get_available_graph(
        store_agent.store_listing_version_id, hide_nodes=False
        store_agent.store_listing_version_id
    )
    graph = await graph_db.get_graph(
        graph_id=graph_meta.id,
        version=graph_meta.version,
        user_id=None,  # Public access
        include_subgraphs=True,
    )
    return graph, store_agent

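The change above splits one store lookup into two steps: resolve the store listing version to graph metadata, then load the full graph (including subgraphs) for public access. A condensed, illustrative sketch of that resolution path, with the function name invented for the example and error handling omitted:

```python
# Illustrative only: the two-step resolution the new code performs.
async def resolve_public_graph(store_listing_version_id: str):
    graph_meta = await store_db.get_available_graph(store_listing_version_id)
    return await graph_db.get_graph(
        graph_id=graph_meta.id,
        version=graph_meta.version,
        user_id=None,  # public access, no ownership check
        include_subgraphs=True,
    )
```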
@@ -123,7 +123,7 @@ def build_missing_credentials_from_graph(

    return {
        field_key: _serialize_missing_credential(field_key, field_info)
        for field_key, (field_info, _, _) in aggregated_fields.items()
        for field_key, (field_info, _node_fields) in aggregated_fields.items()
        if field_key not in matched_keys
    }

@@ -225,6 +225,127 @@ async def get_or_create_library_agent(
    return library_agents[0]


async def get_user_credentials(user_id: str) -> list:
    """
    Get all available credentials for a user.

    Args:
        user_id: The user's ID

    Returns:
        List of user's credentials
    """
    creds_manager = IntegrationCredentialsManager()
    return await creds_manager.store.get_all_creds(user_id)


def find_matching_credential(
    available_creds: list,
    field_info: CredentialsFieldInfo,
):
    """
    Find a credential that matches the required provider, type, and scopes.

    Args:
        available_creds: List of user's available credentials
        field_info: CredentialsFieldInfo with provider, type, and scope requirements

    Returns:
        Matching credential or None
    """
    for cred in available_creds:
        if cred.provider not in field_info.provider:
            continue
        if cred.type not in field_info.supported_types:
            continue
        if not _credential_has_required_scopes(cred, field_info):
            continue
        return cred
    return None


def create_credential_meta_from_match(
    matching_cred,
) -> CredentialsMetaInput:
    """
    Create a CredentialsMetaInput from a matched credential.

    Args:
        matching_cred: The matched credential object

    Returns:
        CredentialsMetaInput instance
    """
    return CredentialsMetaInput(
        id=matching_cred.id,
        provider=matching_cred.provider,  # type: ignore
        type=matching_cred.type,
        title=matching_cred.title,
    )


async def match_credentials_to_requirements(
    user_id: str,
    requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
    """
    Match user's credentials against a dictionary of credential requirements.

    This is the core matching logic shared by both graph and block credential matching.

    Args:
        user_id: The user's ID
        requirements: Dict mapping field names to CredentialsFieldInfo

    Returns:
        tuple[matched_credentials dict, missing_credentials list]
    """
    matched: dict[str, CredentialsMetaInput] = {}
    missing: list[CredentialsMetaInput] = []

    if not requirements:
        return matched, missing

    available_creds = await get_user_credentials(user_id)

    for field_name, field_info in requirements.items():
        matching_cred = find_matching_credential(available_creds, field_info)

        if matching_cred:
            try:
                matched[field_name] = create_credential_meta_from_match(matching_cred)
            except Exception as e:
                logger.error(
                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
                    f"credential_id={matching_cred.id}",
                    exc_info=True,
                )
                provider = next(iter(field_info.provider), "unknown")
                cred_type = next(iter(field_info.supported_types), "api_key")
                missing.append(
                    CredentialsMetaInput(
                        id=field_name,
                        provider=provider,  # type: ignore
                        type=cred_type,  # type: ignore
                        title=f"{field_name} (validation failed: {e})",
                    )
                )
        else:
            provider = next(iter(field_info.provider), "unknown")
            cred_type = next(iter(field_info.supported_types), "api_key")
            missing.append(
                CredentialsMetaInput(
                    id=field_name,
                    provider=provider,  # type: ignore
                    type=cred_type,  # type: ignore
                    title=field_name.replace("_", " ").title(),
                )
            )

    return matched, missing


async def match_user_credentials_to_graph(
    user_id: str,
    graph: GraphModel,
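Taken together, the additions above give a single matching entry point that both the block tool and the graph matcher can reuse. A rough usage sketch follows; the `readiness_report` function and its return shape are invented for illustration and are not part of the diff:

```python
from backend.data.model import CredentialsMetaInput


async def readiness_report(user_id: str, block) -> dict[str, list[str]]:
    """Summarize which credential fields of a block are satisfied (illustrative only)."""
    requirements = block.input_schema.get_credentials_fields_info() or {}
    matched, missing = await match_credentials_to_requirements(user_id, requirements)
    return {
        "matched": list(matched.keys()),
        "missing": [m.title or m.id for m in missing],
    }
```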
@@ -242,9 +363,6 @@ async def match_user_credentials_to_graph(
    Returns:
        tuple[matched_credentials dict, missing_credential_descriptions list]
    """
    graph_credentials_inputs: dict[str, CredentialsMetaInput] = {}
    missing_creds: list[str] = []

    # Get aggregated credentials requirements from the graph
    aggregated_creds = graph.aggregate_credentials_inputs()
    logger.debug(
@@ -252,88 +370,46 @@ async def match_user_credentials_to_graph(
    )

    if not aggregated_creds:
        return graph_credentials_inputs, missing_creds
        return {}, []

    # Get all available credentials for the user
    creds_manager = IntegrationCredentialsManager()
    available_creds = await creds_manager.store.get_all_creds(user_id)

    # For each required credential field, find a matching user credential
    # field_info.provider is a frozenset because aggregate_credentials_inputs()
    # combines requirements from multiple nodes. A credential matches if its
    # provider is in the set of acceptable providers.
    for credential_field_name, (
        credential_requirements,
        _,
        _,
    ) in aggregated_creds.items():
        # Find first matching credential by provider, type, scopes, and host/URL
        matching_cred = next(
            (
                cred
                for cred in available_creds
                if cred.provider in credential_requirements.provider
                and cred.type in credential_requirements.supported_types
                and (
                    cred.type != "oauth2"
                    or _credential_has_required_scopes(cred, credential_requirements)
                )
                and (
                    cred.type != "host_scoped"
                    or _credential_is_for_host(cred, credential_requirements)
                )
                and (
                    cred.provider != ProviderName.MCP
                    or _credential_is_for_mcp_server(cred, credential_requirements)
                )
            ),
            None,
        )

        if matching_cred:
            try:
                graph_credentials_inputs[credential_field_name] = CredentialsMetaInput(
                    id=matching_cred.id,
                    provider=matching_cred.provider,  # type: ignore
                    type=matching_cred.type,
                    title=matching_cred.title,
                )
            except Exception as e:
                logger.error(
                    f"Failed to create CredentialsMetaInput for field '{credential_field_name}': "
                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
                    f"credential_id={matching_cred.id}",
                    exc_info=True,
                )
                missing_creds.append(
                    f"{credential_field_name} (validation failed: {e})"
                )
        else:
            # Build a helpful error message including scope requirements
            error_parts = [
                f"provider in {list(credential_requirements.provider)}",
                f"type in {list(credential_requirements.supported_types)}",
            ]
            if credential_requirements.required_scopes:
                error_parts.append(
                    f"scopes including {list(credential_requirements.required_scopes)}"
                )
            missing_creds.append(
                f"{credential_field_name} (requires {', '.join(error_parts)})"
            )

    logger.info(
        f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched"
    )

    return graph_credentials_inputs, missing_creds

    # Convert aggregated format to simple requirements dict
    requirements = {
        field_name: field_info
        for field_name, (field_info, _node_fields) in aggregated_creds.items()
    }

    # Use shared matching logic
    matched, missing_list = await match_credentials_to_requirements(
        user_id, requirements
    )

    # Convert missing list to string descriptions for backward compatibility
    missing_descriptions = [
        f"{cred.id} (requires provider={cred.provider}, type={cred.type})"
        for cred in missing_list
    ]

    logger.info(
        f"Credential matching complete: {len(matched)}/{len(aggregated_creds)} matched"
    )

    return matched, missing_descriptions


def _credential_has_required_scopes(
    credential: OAuth2Credentials,
    credential: Credentials,
    requirements: CredentialsFieldInfo,
) -> bool:
    """Check if an OAuth2 credential has all the scopes required by the input."""
    """
    Check if a credential has all the scopes required by the block.

    For OAuth2 credentials, verifies that the credential's scopes are a superset
    of the required scopes. For other credential types, returns True (no scope check).
    """
    # Only OAuth2 credentials have scopes to check
    if credential.type != "oauth2":
        return True

    # If no scopes are required, any credential matches
    if not requirements.required_scopes:
        return True
@@ -342,38 +418,6 @@ def _credential_has_required_scopes(
    return set(credential.scopes).issuperset(requirements.required_scopes)


def _credential_is_for_host(
    credential: HostScopedCredentials,
    requirements: CredentialsFieldInfo,
) -> bool:
    """Check if a host-scoped credential matches the host required by the input."""
    # We need to know the host to match host-scoped credentials to.
    # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
    # to discriminator_values. No discriminator_values -> no host to match against.
    if not requirements.discriminator_values:
        return True

    # Check that credential host matches required host.
    # Host-scoped credential inputs are grouped by host, so any item from the set works.
    return credential.matches_url(list(requirements.discriminator_values)[0])


def _credential_is_for_mcp_server(
    credential: Credentials,
    requirements: CredentialsFieldInfo,
) -> bool:
    """Check if an MCP OAuth credential matches the required server URL."""
    if not requirements.discriminator_values:
        return True

    server_url = (
        credential.metadata.get("mcp_server_url")
        if isinstance(credential, OAuth2Credentials)
        else None
    )
    return server_url in requirements.discriminator_values if server_url else False


async def check_user_has_required_credentials(
    user_id: str,
    required_credentials: list[CredentialsMetaInput],
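The scope check retained above reduces to a set-superset comparison: a credential qualifies when it was granted every required OAuth scope, and vacuously qualifies when nothing is required. A standalone illustration with made-up scope names (not taken from the diff):

```python
def has_required_scopes(granted: list[str], required: set[str]) -> bool:
    """True when every required OAuth scope was granted (vacuously true if none required)."""
    return not required or set(granted).issuperset(required)


# Example with invented scope names:
assert has_required_scopes(["repo", "read:user"], {"repo"})
assert not has_required_scopes(["read:user"], {"repo", "read:user"})
```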
@@ -102,19 +102,9 @@ class CredentialsMetaResponse(BaseModel):
    scopes: list[str] | None
    username: str | None
    host: str | None = Field(
        default=None,
        description="Host pattern for host-scoped or MCP server URL for MCP credentials",
        default=None, description="Host pattern for host-scoped credentials"
    )

    @staticmethod
    def get_host(cred: Credentials) -> str | None:
        """Extract host from credential: HostScoped host or MCP server URL."""
        if isinstance(cred, HostScopedCredentials):
            return cred.host
        if isinstance(cred, OAuth2Credentials) and cred.provider == ProviderName.MCP:
            return (cred.metadata or {}).get("mcp_server_url")
        return None


@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
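The removed `get_host` accessor folded two credential shapes into one `host` value; with it gone, only host-scoped credentials populate the field. A minimal side-by-side sketch of the two behaviours, derived from the lines above (the function names here are illustrative, not part of the codebase):

```python
from backend.data.model import Credentials, HostScopedCredentials, OAuth2Credentials
from backend.integrations.providers import ProviderName


def host_before(cred: Credentials) -> str | None:
    """Old behaviour: host pattern, or the MCP server URL for MCP OAuth credentials."""
    if isinstance(cred, HostScopedCredentials):
        return cred.host
    if isinstance(cred, OAuth2Credentials) and cred.provider == ProviderName.MCP:
        return (cred.metadata or {}).get("mcp_server_url")
    return None


def host_after(cred: Credentials) -> str | None:
    """New behaviour: only host-scoped credentials expose a host."""
    return cred.host if isinstance(cred, HostScopedCredentials) else None
```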
@@ -189,7 +179,9 @@ async def callback(
        title=credentials.title,
        scopes=credentials.scopes,
        username=credentials.username,
        host=(CredentialsMetaResponse.get_host(credentials)),
        host=(
            credentials.host if isinstance(credentials, HostScopedCredentials) else None
        ),
    )

@@ -207,7 +199,7 @@ async def list_credentials(
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=CredentialsMetaResponse.get_host(cred),
            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
        )
        for cred in credentials
    ]
@@ -230,7 +222,7 @@ async def list_credentials_by_provider(
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=CredentialsMetaResponse.get_host(cred),
            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
        )
        for cred in credentials
    ]
@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
    on_graph_activate,
    on_graph_deactivate,
)
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson
@@ -374,7 +371,7 @@ async def get_library_agent_by_graph_id(


async def add_generated_agent_image(
    graph: graph_db.GraphBaseMeta,
    graph: graph_db.BaseGraph,
    user_id: str,
    library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
@@ -540,92 +537,6 @@ async def update_agent_version_in_library(
    return library_model.LibraryAgent.from_db(lib)


async def create_graph_in_library(
    graph: graph_db.Graph,
    user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
    """Create a new graph and add it to the user's library."""
    graph.version = 1
    graph_model = graph_db.make_graph_model(graph, user_id)
    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)

    created_graph = await graph_db.create_graph(graph_model, user_id)

    library_agents = await create_library_agent(
        graph=created_graph,
        user_id=user_id,
        sensitive_action_safe_mode=True,
        create_library_agents_for_sub_graphs=False,
    )

    if created_graph.is_active:
        created_graph = await on_graph_activate(created_graph, user_id=user_id)

    return created_graph, library_agents[0]


async def update_graph_in_library(
    graph: graph_db.Graph,
    user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
    """Create a new version of an existing graph and update the library entry."""
    existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
    current_active_version = (
        next((v for v in existing_versions if v.is_active), None)
        if existing_versions
        else None
    )
    graph.version = (
        max(v.version for v in existing_versions) + 1 if existing_versions else 1
    )

    graph_model = graph_db.make_graph_model(graph, user_id)
    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)

    created_graph = await graph_db.create_graph(graph_model, user_id)

    library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
    if not library_agent:
        raise NotFoundError(f"Library agent not found for graph {created_graph.id}")

    library_agent = await update_library_agent_version_and_settings(
        user_id, created_graph
    )

    if created_graph.is_active:
        created_graph = await on_graph_activate(created_graph, user_id=user_id)
        await graph_db.set_graph_active_version(
            graph_id=created_graph.id,
            version=created_graph.version,
            user_id=user_id,
        )
        if current_active_version:
            await on_graph_deactivate(current_active_version, user_id=user_id)

    return created_graph, library_agent


async def update_library_agent_version_and_settings(
    user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
    """Update library agent to point to new graph version and sync settings."""
    library = await update_agent_version_in_library(
        user_id, agent_graph.id, agent_graph.version
    )
    updated_settings = GraphSettings.from_graph(
        graph=agent_graph,
        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
    )
    if updated_settings != library.settings:
        library = await update_library_agent(
            library_agent_id=library.id,
            user_id=user_id,
            settings=updated_settings,
        )
    return library


async def update_library_agent(
    library_agent_id: str,
    user_id: str,
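The removed `update_graph_in_library` helper documents the versioning convention used here: a new save becomes the highest existing version plus one, and the previously active version is deactivated only after the new one activates. A tiny standalone illustration of that numbering rule (function name invented for the example):

```python
def next_graph_version(existing_versions: list[int]) -> int:
    """New version = highest existing version + 1; the first save starts at 1."""
    return max(existing_versions) + 1 if existing_versions else 1


assert next_graph_version([]) == 1
assert next_graph_version([1, 2, 5]) == 6
```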
@@ -1,412 +0,0 @@
"""
MCP (Model Context Protocol) API routes.

Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""

import logging
from typing import Annotated, Any
from urllib.parse import urlparse

import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field

from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings

logger = logging.getLogger(__name__)

settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()


# ====================== Tool Discovery ====================== #


class DiscoverToolsRequest(BaseModel):
    """Request to discover tools on an MCP server."""

    server_url: str = Field(description="URL of the MCP server")
    auth_token: str | None = Field(
        default=None,
        description="Optional Bearer token for authenticated MCP servers",
    )


class MCPToolResponse(BaseModel):
    """A single MCP tool returned by discovery."""

    name: str
    description: str
    input_schema: dict[str, Any]


class DiscoverToolsResponse(BaseModel):
    """Response containing the list of tools available on an MCP server."""

    tools: list[MCPToolResponse]
    server_name: str | None = None
    protocol_version: str | None = None


@router.post(
    "/discover-tools",
    summary="Discover available tools on an MCP server",
    response_model=DiscoverToolsResponse,
)
async def discover_tools(
    request: DiscoverToolsRequest,
    user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
    """
    Connect to an MCP server and return its available tools.

    If the user has a stored MCP credential for this server URL, it will be
    used automatically — no need to pass an explicit auth token.
    """
    auth_token = request.auth_token

    # Auto-use stored MCP credential when no explicit token is provided
    if not auth_token:
        try:
            mcp_creds = await creds_manager.store.get_creds_by_provider(
                user_id, str(ProviderName.MCP)
            )
            # Find the freshest credential for this server URL
            best_cred: OAuth2Credentials | None = None
            for cred in mcp_creds:
                if (
                    isinstance(cred, OAuth2Credentials)
                    and cred.metadata.get("mcp_server_url") == request.server_url
                ):
                    if best_cred is None or (
                        (cred.access_token_expires_at or 0)
                        > (best_cred.access_token_expires_at or 0)
                    ):
                        best_cred = cred
            if best_cred:
                logger.info(
                    f"Using MCP credential {best_cred.id} for {request.server_url}, "
                    f"expires_at={best_cred.access_token_expires_at}"
                )
                auth_token = best_cred.access_token.get_secret_value()
        except Exception:
            logger.debug("Could not look up stored MCP credentials", exc_info=True)

    try:
        client = MCPClient(request.server_url, auth_token=auth_token)

        init_result = await client.initialize()
        tools = await client.list_tools()

        return DiscoverToolsResponse(
            tools=[
                MCPToolResponse(
                    name=t.name,
                    description=t.description,
                    input_schema=t.input_schema,
                )
                for t in tools
            ],
            server_name=init_result.get("serverInfo", {}).get("name"),
            protocol_version=init_result.get("protocolVersion"),
        )
    except HTTPClientError as e:
        if e.status_code in (401, 403):
            logger.warning(
                f"MCP server returned {e.status_code} for {request.server_url}: {e}"
            )
            raise fastapi.HTTPException(
                status_code=401,
                detail="This MCP server requires authentication. "
                "Please provide a valid auth token.",
            )
        raise fastapi.HTTPException(status_code=502, detail=str(e))
    except MCPClientError as e:
        raise fastapi.HTTPException(status_code=502, detail=str(e))
    except Exception as e:
        logger.exception("MCP tool discovery failed")
        raise fastapi.HTTPException(
            status_code=502,
            detail=f"Failed to connect to MCP server: {str(e)}",
        )


# ======================== OAuth Flow ======================== #


class MCPOAuthLoginRequest(BaseModel):
    """Request to start an OAuth flow for an MCP server."""

    server_url: str = Field(description="URL of the MCP server that requires OAuth")


class MCPOAuthLoginResponse(BaseModel):
    """Response with the OAuth login URL for the user to authenticate."""

    login_url: str
    state_token: str


@router.post(
    "/oauth/login",
    summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
    request: MCPOAuthLoginRequest,
    user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
    """
    Discover OAuth metadata from the MCP server and return a login URL.

    1. Discovers the protected-resource metadata (RFC 9728)
    2. Fetches the authorization server metadata (RFC 8414)
    3. Performs Dynamic Client Registration (RFC 7591) if available
    4. Returns the authorization URL for the frontend to open in a popup
    """
    client = MCPClient(request.server_url)

    # Step 1: Discover protected-resource metadata (RFC 9728)
    try:
        protected_resource = await client.discover_auth()
    except Exception as e:
        raise fastapi.HTTPException(
            status_code=502,
            detail=f"Failed to discover OAuth metadata: {e}",
        )

    metadata: dict[str, Any] | None = None

    if protected_resource and "authorization_servers" in protected_resource:
        auth_server_url = protected_resource["authorization_servers"][0]
        resource_url = protected_resource.get("resource", request.server_url)

        # Step 2a: Discover auth-server metadata (RFC 8414)
        try:
            metadata = await client.discover_auth_server_metadata(auth_server_url)
        except Exception as e:
            raise fastapi.HTTPException(
                status_code=502,
                detail=f"Failed to discover authorization server metadata: {e}",
            )
    else:
        # Fallback: Some MCP servers (e.g. Linear) are their own auth server
        # and serve OAuth metadata directly without protected-resource metadata.
        # Don't assume a resource_url — omitting it lets the auth server choose
        # the correct audience for the token (RFC 8707 resource is optional).
        resource_url = None
        try:
            metadata = await client.discover_auth_server_metadata(request.server_url)
        except Exception:
            pass

    if not metadata or "authorization_endpoint" not in metadata:
        raise fastapi.HTTPException(
            status_code=400,
            detail="This MCP server does not advertise OAuth support. "
            "You may need to provide an auth token manually.",
        )

    authorize_url = metadata["authorization_endpoint"]
    token_url = metadata["token_endpoint"]
    registration_endpoint = metadata.get("registration_endpoint")
    revoke_url = metadata.get("revocation_endpoint")

    # Step 3: Dynamic Client Registration (RFC 7591) if available
    frontend_base_url = settings.config.frontend_base_url
    if not frontend_base_url:
        raise fastapi.HTTPException(
            status_code=500,
            detail="Frontend base URL is not configured.",
        )
    redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"

    client_id = ""
    client_secret = ""
    if registration_endpoint:
        reg_result = await _register_mcp_client(
            registration_endpoint, redirect_uri, request.server_url
        )
        if reg_result:
            client_id = reg_result.get("client_id", "")
            client_secret = reg_result.get("client_secret", "")

    if not client_id:
        client_id = "autogpt-platform"

    # Step 4: Store state token with OAuth metadata for the callback
    scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
        "scopes_supported", []
    )
    state_token, code_challenge = await creds_manager.store.store_state_token(
        user_id,
        str(ProviderName.MCP),
        scopes,
        state_metadata={
            "authorize_url": authorize_url,
            "token_url": token_url,
            "revoke_url": revoke_url,
            "resource_url": resource_url,
            "server_url": request.server_url,
            "client_id": client_id,
            "client_secret": client_secret,
        },
    )

    # Step 5: Build and return the login URL
    handler = MCPOAuthHandler(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=redirect_uri,
        authorize_url=authorize_url,
        token_url=token_url,
        resource_url=resource_url,
    )
    login_url = handler.get_login_url(
        scopes, state_token, code_challenge=code_challenge
    )

    return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)


class MCPOAuthCallbackRequest(BaseModel):
    """Request to exchange an OAuth code for tokens."""

    code: str = Field(description="Authorization code from OAuth callback")
    state_token: str = Field(description="State token for CSRF verification")


class MCPOAuthCallbackResponse(BaseModel):
    """Response after successfully storing OAuth credentials."""

    credential_id: str


@router.post(
    "/oauth/callback",
    summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
    request: MCPOAuthCallbackRequest,
    user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
    """
    Exchange the authorization code for tokens and store the credential.

    The frontend calls this after receiving the OAuth code from the popup.
    On success, subsequent ``/discover-tools`` calls for the same server URL
    will automatically use the stored credential.
    """
    valid_state = await creds_manager.store.verify_state_token(
        user_id, request.state_token, str(ProviderName.MCP)
    )
    if not valid_state:
        raise fastapi.HTTPException(
            status_code=400,
            detail="Invalid or expired state token.",
        )

    meta = valid_state.state_metadata
    frontend_base_url = settings.config.frontend_base_url
    redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"

    handler = MCPOAuthHandler(
        client_id=meta["client_id"],
        client_secret=meta.get("client_secret", ""),
        redirect_uri=redirect_uri,
        authorize_url=meta["authorize_url"],
        token_url=meta["token_url"],
        revoke_url=meta.get("revoke_url"),
        resource_url=meta.get("resource_url"),
    )

    try:
        credentials = await handler.exchange_code_for_tokens(
            request.code, valid_state.scopes, valid_state.code_verifier
        )
    except Exception as e:
        logger.exception("MCP OAuth token exchange failed")
        raise fastapi.HTTPException(
            status_code=400,
            detail=f"OAuth token exchange failed: {e}",
        )

    # Enrich credential metadata for future lookup and token refresh
    if credentials.metadata is None:
        credentials.metadata = {}
    credentials.metadata["mcp_server_url"] = meta["server_url"]
    credentials.metadata["mcp_client_id"] = meta["client_id"]
    credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
    credentials.metadata["mcp_token_url"] = meta["token_url"]
    credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")

    hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
    credentials.title = f"MCP: {hostname}"

    # Remove old MCP credentials for the same server to prevent stale token buildup
    try:
        old_creds = await creds_manager.store.get_creds_by_provider(
            user_id, str(ProviderName.MCP)
        )
        for old in old_creds:
            if (
                isinstance(old, OAuth2Credentials)
                and old.metadata.get("mcp_server_url") == meta["server_url"]
            ):
                await creds_manager.store.delete_creds_by_id(user_id, old.id)
                logger.info(
                    f"Removed old MCP credential {old.id} for {meta['server_url']}"
                )
    except Exception:
        logger.debug("Could not clean up old MCP credentials", exc_info=True)

    await creds_manager.create(user_id, credentials)

    return CredentialsMetaResponse(
        id=credentials.id,
        provider=credentials.provider,
        type=credentials.type,
        title=credentials.title,
        scopes=credentials.scopes,
        username=credentials.username,
        host=credentials.metadata.get("mcp_server_url"),
    )


# ======================== Helpers ======================== #


async def _register_mcp_client(
    registration_endpoint: str,
    redirect_uri: str,
    server_url: str,
) -> dict[str, Any] | None:
    """Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
    try:
        response = await Requests(raise_for_status=True).post(
            registration_endpoint,
            json={
                "client_name": "AutoGPT Platform",
                "redirect_uris": [redirect_uri],
                "grant_types": ["authorization_code"],
                "response_types": ["code"],
                "token_endpoint_auth_method": "client_secret_post",
            },
        )
        data = response.json()
        if isinstance(data, dict) and "client_id" in data:
            return data
        return None
    except Exception as e:
        logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
        return None
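The deleted `/oauth/login` endpoint above walks a fixed discovery order before it can build a login URL. A condensed sketch of that sequence, using only the `MCPClient` methods shown in the removed file; the wrapper function and its error handling are illustrative, not part of the codebase:

```python
# Condensed sketch of the discovery order the deleted endpoint walked through.
async def build_authorize_endpoint(server_url: str) -> str:
    client = MCPClient(server_url)

    # 1. Protected-resource metadata (RFC 9728) points at the authorization server.
    protected = await client.discover_auth()

    # 2. Authorization-server metadata (RFC 8414) yields the OAuth endpoints.
    if protected and "authorization_servers" in protected:
        metadata = await client.discover_auth_server_metadata(
            protected["authorization_servers"][0]
        )
    else:
        # Some servers are their own auth server and publish metadata directly.
        metadata = await client.discover_auth_server_metadata(server_url)

    if not metadata or "authorization_endpoint" not in metadata:
        raise RuntimeError("Server does not advertise OAuth support")

    # 3. Dynamic Client Registration (RFC 7591) would run here when
    #    metadata.get("registration_endpoint") is present; omitted in this sketch.
    return metadata["authorization_endpoint"]
```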
@@ -1,388 +0,0 @@
"""Tests for MCP API routes."""

from unittest.mock import AsyncMock, patch

import fastapi
import fastapi.testclient
from autogpt_libs.auth import get_user_id

from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError

app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
client = fastapi.testclient.TestClient(app)


class TestDiscoverTools:
    def test_discover_tools_success(self):
        mock_tools = [
            MCPTool(
                name="get_weather",
                description="Get weather for a city",
                input_schema={
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            ),
            MCPTool(
                name="add_numbers",
                description="Add two numbers",
                input_schema={
                    "type": "object",
                    "properties": {
                        "a": {"type": "number"},
                        "b": {"type": "number"},
                    },
                },
            ),
        ]

        with (patch("backend.api.features.mcp.routes.MCPClient") as MockClient,):
            instance = MockClient.return_value
            instance.initialize = AsyncMock(
                return_value={
                    "protocolVersion": "2025-03-26",
                    "serverInfo": {"name": "test-server"},
                }
            )
            instance.list_tools = AsyncMock(return_value=mock_tools)

            response = client.post(
                "/discover-tools",
                json={"server_url": "https://mcp.example.com/mcp"},
            )

            assert response.status_code == 200
            data = response.json()
            assert len(data["tools"]) == 2
            assert data["tools"][0]["name"] == "get_weather"
            assert data["tools"][1]["name"] == "add_numbers"
            assert data["server_name"] == "test-server"
            assert data["protocol_version"] == "2025-03-26"

    def test_discover_tools_with_auth_token(self):
        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
            instance = MockClient.return_value
            instance.initialize = AsyncMock(
                return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
            )
            instance.list_tools = AsyncMock(return_value=[])

            response = client.post(
                "/discover-tools",
                json={
                    "server_url": "https://mcp.example.com/mcp",
                    "auth_token": "my-secret-token",
                },
            )

            assert response.status_code == 200
            MockClient.assert_called_once_with(
                "https://mcp.example.com/mcp",
                auth_token="my-secret-token",
            )

    def test_discover_tools_auto_uses_stored_credential(self):
        """When no explicit token is given, stored MCP credentials are used."""
        from pydantic import SecretStr

        from backend.data.model import OAuth2Credentials

        stored_cred = OAuth2Credentials(
            provider="mcp",
            title="MCP: example.com",
            access_token=SecretStr("stored-token-123"),
            refresh_token=None,
            access_token_expires_at=None,
            refresh_token_expires_at=None,
            scopes=[],
            metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
        )

        with (
            patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
        ):
            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
            instance = MockClient.return_value
            instance.initialize = AsyncMock(
                return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
            )
            instance.list_tools = AsyncMock(return_value=[])

            response = client.post(
                "/discover-tools",
                json={"server_url": "https://mcp.example.com/mcp"},
            )

            assert response.status_code == 200
            MockClient.assert_called_once_with(
                "https://mcp.example.com/mcp",
                auth_token="stored-token-123",
            )

    def test_discover_tools_mcp_error(self):
        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
            instance = MockClient.return_value
            instance.initialize = AsyncMock(
                side_effect=MCPClientError("Connection refused")
            )

            response = client.post(
                "/discover-tools",
                json={"server_url": "https://bad-server.example.com/mcp"},
            )

            assert response.status_code == 502
            assert "Connection refused" in response.json()["detail"]

    def test_discover_tools_generic_error(self):
        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
            instance = MockClient.return_value
            instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))

            response = client.post(
                "/discover-tools",
                json={"server_url": "https://timeout.example.com/mcp"},
            )

            assert response.status_code == 502
            assert "Failed to connect" in response.json()["detail"]

    def test_discover_tools_auth_required(self):
        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
            instance = MockClient.return_value
            instance.initialize = AsyncMock(
                side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
            )

            response = client.post(
                "/discover-tools",
                json={"server_url": "https://auth-server.example.com/mcp"},
            )

            assert response.status_code == 401
            assert "requires authentication" in response.json()["detail"]

    def test_discover_tools_forbidden(self):
        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
            instance = MockClient.return_value
            instance.initialize = AsyncMock(
                side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
            )

            response = client.post(
                "/discover-tools",
                json={"server_url": "https://auth-server.example.com/mcp"},
            )

            assert response.status_code == 401
            assert "requires authentication" in response.json()["detail"]

    def test_discover_tools_missing_url(self):
        response = client.post("/discover-tools", json={})
        assert response.status_code == 422


class TestOAuthLogin:
    def test_oauth_login_success(self):
        with (
            patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
            patch("backend.api.features.mcp.routes.settings") as mock_settings,
            patch(
                "backend.api.features.mcp.routes._register_mcp_client"
            ) as mock_register,
        ):
            instance = MockClient.return_value
            instance.discover_auth = AsyncMock(
                return_value={
                    "authorization_servers": ["https://auth.sentry.io"],
                    "resource": "https://mcp.sentry.dev/mcp",
                    "scopes_supported": ["openid"],
                }
            )
            instance.discover_auth_server_metadata = AsyncMock(
                return_value={
                    "authorization_endpoint": "https://auth.sentry.io/authorize",
                    "token_endpoint": "https://auth.sentry.io/token",
                    "registration_endpoint": "https://auth.sentry.io/register",
                }
            )
            mock_register.return_value = {
                "client_id": "registered-client-id",
                "client_secret": "registered-secret",
            }
            mock_cm.store.store_state_token = AsyncMock(
                return_value=("state-token-123", "code-challenge-abc")
            )
            mock_settings.config.frontend_base_url = "http://localhost:3000"

            response = client.post(
                "/oauth/login",
                json={"server_url": "https://mcp.sentry.dev/mcp"},
            )

            assert response.status_code == 200
            data = response.json()
            assert "login_url" in data
            assert data["state_token"] == "state-token-123"
            assert "auth.sentry.io/authorize" in data["login_url"]
            assert "registered-client-id" in data["login_url"]

    def test_oauth_login_no_oauth_support(self):
        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
            instance = MockClient.return_value
            instance.discover_auth = AsyncMock(return_value=None)

            response = client.post(
                "/oauth/login",
                json={"server_url": "https://simple-server.example.com/mcp"},
            )

            assert response.status_code == 400
            assert "does not advertise OAuth" in response.json()["detail"]

    def test_oauth_login_fallback_to_public_client(self):
        """When DCR is unavailable, falls back to default public client ID."""
        with (
            patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
            patch("backend.api.features.mcp.routes.settings") as mock_settings,
        ):
            instance = MockClient.return_value
            instance.discover_auth = AsyncMock(
                return_value={
                    "authorization_servers": ["https://auth.example.com"],
                    "resource": "https://mcp.example.com/mcp",
                }
            )
            instance.discover_auth_server_metadata = AsyncMock(
                return_value={
                    "authorization_endpoint": "https://auth.example.com/authorize",
                    "token_endpoint": "https://auth.example.com/token",
                    # No registration_endpoint
                }
            )
            mock_cm.store.store_state_token = AsyncMock(
                return_value=("state-abc", "challenge-xyz")
            )
            mock_settings.config.frontend_base_url = "http://localhost:3000"

            response = client.post(
                "/oauth/login",
                json={"server_url": "https://mcp.example.com/mcp"},
            )

            assert response.status_code == 200
            data = response.json()
            assert "autogpt-platform" in data["login_url"]


class TestOAuthCallback:
    def test_oauth_callback_success(self):
        from pydantic import SecretStr

        from backend.data.model import OAuth2Credentials

        mock_creds = OAuth2Credentials(
            provider="mcp",
            title=None,
            access_token=SecretStr("access-token-xyz"),
            refresh_token=None,
            access_token_expires_at=None,
            refresh_token_expires_at=None,
            scopes=[],
            metadata={
                "mcp_token_url": "https://auth.sentry.io/token",
                "mcp_resource_url": "https://mcp.sentry.dev/mcp",
            },
        )

        with (
            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
            patch("backend.api.features.mcp.routes.settings") as mock_settings,
            patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
        ):
            mock_settings.config.frontend_base_url = "http://localhost:3000"

            # Mock state verification
            mock_state = AsyncMock()
            mock_state.state_metadata = {
                "authorize_url": "https://auth.sentry.io/authorize",
                "token_url": "https://auth.sentry.io/token",
                "client_id": "test-client-id",
                "client_secret": "test-secret",
                "server_url": "https://mcp.sentry.dev/mcp",
            }
            mock_state.scopes = ["openid"]
            mock_state.code_verifier = "verifier-123"
            mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
            mock_cm.create = AsyncMock()

handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
return_value=mock_creds
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock old credential cleanup
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "id" in data
|
|
||||||
assert data["provider"] == "mcp"
|
|
||||||
assert data["type"] == "oauth2"
|
|
||||||
mock_cm.create.assert_called_once()
|
|
||||||
|
|
||||||
def test_oauth_callback_invalid_state(self):
|
|
||||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code", "state_token": "bad-state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "Invalid or expired" in response.json()["detail"]
|
|
||||||
|
|
||||||
def test_oauth_callback_token_exchange_fails(self):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
"client_id": "cid",
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = []
|
|
||||||
mock_state.code_verifier = "v"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
side_effect=RuntimeError("Token exchange failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "bad-code", "state_token": "state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "token exchange failed" in response.json()["detail"].lower()
|
|
||||||
@@ -68,7 +68,7 @@ async def test_user(server, test_user_id: str):
    await PrismaUser.prisma().delete(where={"id": test_user_id})


-@pytest_asyncio.fixture(loop_scope="session")
+@pytest_asyncio.fixture
async def test_oauth_app(test_user: str):
    """Create a test OAuth application in the database."""
    app_id = str(uuid.uuid4())
@@ -123,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
    return generate_pkce()


-@pytest_asyncio.fixture(loop_scope="session")
+@pytest_asyncio.fixture
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
    """
    Create an async HTTP client that talks directly to the FastAPI app.
@@ -288,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
    assert query_params["error"][0] == "invalid_client"


-@pytest_asyncio.fixture(loop_scope="session")
+@pytest_asyncio.fixture
async def inactive_oauth_app(test_user: str):
    """Create an inactive test OAuth application in the database."""
    app_id = str(uuid.uuid4())
@@ -1005,7 +1005,7 @@ async def test_token_refresh_revoked(
    assert "revoked" in response.json()["detail"].lower()


-@pytest_asyncio.fixture(loop_scope="session")
+@pytest_asyncio.fixture
async def other_oauth_app(test_user: str):
    """Create a second OAuth application for cross-app tests."""
    app_id = str(uuid.uuid4())

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timezone
-from typing import Any, Literal, overload
+from typing import Any, Literal

import fastapi
import prisma.enums
@@ -11,8 +11,8 @@ import prisma.types

from backend.data.db import transaction
from backend.data.graph import (
+    GraphMeta,
    GraphModel,
-    GraphModelWithoutNodes,
    get_graph,
    get_graph_as_admin,
    get_sub_graphs,
@@ -334,22 +334,7 @@ async def get_store_agent_details(
        raise DatabaseError("Failed to fetch agent details") from e


-@overload
-async def get_available_graph(
-    store_listing_version_id: str, hide_nodes: Literal[False]
-) -> GraphModel: ...
-
-
-@overload
-async def get_available_graph(
-    store_listing_version_id: str, hide_nodes: Literal[True] = True
-) -> GraphModelWithoutNodes: ...
-
-
-async def get_available_graph(
-    store_listing_version_id: str,
-    hide_nodes: bool = True,
-) -> GraphModelWithoutNodes | GraphModel:
+async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
    try:
        # Get avaialble, non-deleted store listing version
        store_listing_version = (
@@ -359,7 +344,7 @@ async def get_available_graph(
                    "isAvailable": True,
                    "isDeleted": False,
                },
-                include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
+                include={"AgentGraph": {"include": {"Nodes": True}}},
            )
        )

@@ -369,9 +354,7 @@ async def get_available_graph(
                detail=f"Store listing version {store_listing_version_id} not found",
            )

-        return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
-            store_listing_version.AgentGraph
-        )
+        return GraphModel.from_db(store_listing_version.AgentGraph).meta()

    except Exception as e:
        logger.error(f"Error getting agent: {e}")

@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
    cleanup_embeddings: list,
):
    """Test unified search pagination works correctly."""
-    # Use a unique search term to avoid matching other test data
-    unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
-
    # Create multiple items
    content_ids = []
    for i in range(5):
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
            content_type=ContentType.BLOCK,
            content_id=content_id,
            embedding=mock_embedding,
-            searchable_text=f"{unique_term} item number {i}",
+            searchable_text=f"pagination test item number {i}",
            metadata={"index": i},
            user_id=None,
        )

    # Get first page
    page1_results, total1 = await unified_hybrid_search(
-        query=unique_term,
+        query="pagination test",
        content_types=[ContentType.BLOCK],
        page=1,
        page_size=2,
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(

    # Get second page
    page2_results, total2 = await unified_hybrid_search(
-        query=unique_term,
+        query="pagination test",
        content_types=[ContentType.BLOCK],
        page=2,
        page_size=2,

@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
    StyleType,
    UpscaleOption,
)
-from backend.data.graph import GraphBaseMeta
+from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import Requests
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
    DIGITAL_ART = "digital art"


-async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
+async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
    if settings.config.use_agent_image_generation_v2:
        return await generate_agent_image_v2(graph=agent)
    else:
        return await generate_agent_image_v1(agent=agent)


-async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
+async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
    """
    Generate an image for an agent using Ideogram model.
    Returns:
@@ -54,17 +54,14 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
    description = f"{name} ({graph.description})" if graph.description else name

    prompt = (
-        "Create a visually striking retro-futuristic vector pop art illustration "
-        f'prominently featuring "{name}" in bold typography. The image clearly and '
-        f"literally depicts a {description}, along with recognizable objects directly "
-        f"associated with the primary function of a {name}. "
-        f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
-        f"clearly conveying the purpose of a {name}. "
-        "Maintain vibrant, limited-palette colors, sharp vector lines, "
-        "geometric shapes, flat illustration techniques, and solid colors "
-        "without gradients or shading. Preserve a retro-futuristic aesthetic "
-        "influenced by mid-century futurism and 1960s psychedelia, "
-        "prioritizing clear visual storytelling and thematic clarity above all else."
+        f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
+        f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
+        f"along with recognizable objects directly associated with the primary function of a {name}. "
+        f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
+        f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
+        f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
+        f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
+        f"prioritizing clear visual storytelling and thematic clarity above all else."
    )

    custom_colors = [
@@ -102,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
    return io.BytesIO(response.content)


-async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
+async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
    """
    Generate an image for an agent using Flux model via Replicate API.

    Args:
-        agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
+        agent (Graph): The agent to generate an image for

    Returns:
        io.BytesIO: The generated image as bytes
@@ -117,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
        raise ValueError("Missing Replicate API key in settings")

    # Construct prompt from agent details
-    prompt = (
-        "Create a visually engaging app store thumbnail for the AI agent "
-        "that highlights what it does in a clear and captivating way:\n"
-        f"- **Name**: {agent.name}\n"
-        f"- **Description**: {agent.description}\n"
-        f"Focus on showcasing its core functionality with an appealing design."
-    )
+    prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."

    # Set up Replicate client
    client = ReplicateClient(api_token=settings.secrets.replicate_api_key)

@@ -278,7 +278,7 @@ async def get_agent(
)
async def get_graph_meta_by_store_listing_version_id(
    store_listing_version_id: str,
-) -> backend.data.graph.GraphModelWithoutNodes:
+) -> backend.data.graph.GraphMeta:
    """
    Get Agent Graph from Store Listing Version ID.
    """

@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
from backend.util.virus_scanner import scan_content_safe

from .library import db as library_db
+from .library import model as library_model
from .store.model import StoreAgentDetails


@@ -822,16 +823,18 @@ async def update_graph(
    graph: graph_db.Graph,
    user_id: Annotated[str, Security(get_user_id)],
) -> graph_db.GraphModel:
+    # Sanity check
    if graph.id and graph.id != graph_id:
        raise HTTPException(400, detail="Graph ID does not match ID in URI")

+    # Determine new version
    existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
    if not existing_versions:
        raise HTTPException(404, detail=f"Graph #{graph_id} not found")
+    latest_version_number = max(g.version for g in existing_versions)
+    graph.version = latest_version_number + 1

-    graph.version = max(g.version for g in existing_versions) + 1
    current_active_version = next((v for v in existing_versions if v.is_active), None)

    graph = graph_db.make_graph_model(graph, user_id)
    graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
    graph.validate_graph(for_run=False)
@@ -839,23 +842,27 @@ async def update_graph(
    new_graph_version = await graph_db.create_graph(graph, user_id=user_id)

    if new_graph_version.is_active:
-        await library_db.update_library_agent_version_and_settings(
-            user_id, new_graph_version
-        )
+        # Keep the library agent up to date with the new active version
+        await _update_library_agent_version_and_settings(user_id, new_graph_version)
+        # Handle activation of the new graph first to ensure continuity
        new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
+        # Ensure new version is the only active version
        await graph_db.set_graph_active_version(
            graph_id=graph_id, version=new_graph_version.version, user_id=user_id
        )
        if current_active_version:
+            # Handle deactivation of the previously active version
            await on_graph_deactivate(current_active_version, user_id=user_id)

+    # Fetch new graph version *with sub-graphs* (needed for credentials input schema)
    new_graph_version_with_subgraphs = await graph_db.get_graph(
        graph_id,
        new_graph_version.version,
        user_id=user_id,
        include_subgraphs=True,
    )
-    assert new_graph_version_with_subgraphs
+    assert new_graph_version_with_subgraphs  # make type checker happy
    return new_graph_version_with_subgraphs


@@ -893,15 +900,33 @@ async def set_graph_active_version(
    )

    # Keep the library agent up to date with the new active version
-    await library_db.update_library_agent_version_and_settings(
-        user_id, new_active_graph
-    )
+    await _update_library_agent_version_and_settings(user_id, new_active_graph)

    if current_active_graph and current_active_graph.version != new_active_version:
        # Handle deactivation of the previously active version
        await on_graph_deactivate(current_active_graph, user_id=user_id)


+async def _update_library_agent_version_and_settings(
+    user_id: str, agent_graph: graph_db.GraphModel
+) -> library_model.LibraryAgent:
+    library = await library_db.update_agent_version_in_library(
+        user_id, agent_graph.id, agent_graph.version
+    )
+    updated_settings = GraphSettings.from_graph(
+        graph=agent_graph,
+        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
+        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
+    )
+    if updated_settings != library.settings:
+        library = await library_db.update_library_agent(
+            library_agent_id=library.id,
+            user_id=user_id,
+            settings=updated_settings,
+        )
+    return library


@v1_router.patch(
    path="/graphs/{graph_id}/settings",
    summary="Update graph settings",

@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db
import backend.api.features.library.model
import backend.api.features.library.routes
-import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.postmark.postmark
@@ -41,10 +40,6 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
-from backend.api.features.chat.completion_consumer import (
-    start_completion_consumer,
-    stop_completion_consumer,
-)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -123,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
    await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
    await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

-    # Start chat completion consumer for Redis Streams notifications
-    try:
-        await start_completion_consumer()
-    except Exception as e:
-        logger.warning(f"Could not start chat completion consumer: {e}")
-
    with launch_darkly_context():
        yield

-    # Stop chat completion consumer
-    try:
-        await stop_completion_consumer()
-    except Exception as e:
-        logger.warning(f"Error stopping chat completion consumer: {e}")
-
    try:
        await shutdown_cloud_storage_handler()
    except Exception as e:
@@ -344,11 +327,6 @@ app.include_router(
    tags=["workspace"],
    prefix="/api/workspace",
)
-app.include_router(
-    mcp_routes.router,
-    tags=["v2", "mcp"],
-    prefix="/api/mcp",
-)
app.include_router(
    backend.api.features.oauth.router,
    tags=["oauth"],

@@ -1,28 +0,0 @@
-"""ElevenLabs integration blocks - test credentials and shared utilities."""
-
-from typing import Literal
-
-from pydantic import SecretStr
-
-from backend.data.model import APIKeyCredentials, CredentialsMetaInput
-from backend.integrations.providers import ProviderName
-
-TEST_CREDENTIALS = APIKeyCredentials(
-    id="01234567-89ab-cdef-0123-456789abcdef",
-    provider="elevenlabs",
-    api_key=SecretStr("mock-elevenlabs-api-key"),
-    title="Mock ElevenLabs API key",
-    expires_at=None,
-)
-
-TEST_CREDENTIALS_INPUT = {
-    "provider": TEST_CREDENTIALS.provider,
-    "id": TEST_CREDENTIALS.id,
-    "type": TEST_CREDENTIALS.type,
-    "title": TEST_CREDENTIALS.title,
-}
-
-ElevenLabsCredentials = APIKeyCredentials
-ElevenLabsCredentialsInput = CredentialsMetaInput[
-    Literal[ProviderName.ELEVENLABS], Literal["api_key"]
-]

@@ -1,77 +0,0 @@
-"""Text encoding block for converting special characters to escape sequences."""
-
-import codecs
-
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
-from backend.data.model import SchemaField
-
-
-class TextEncoderBlock(Block):
-    """
-    Encodes a string by converting special characters into escape sequences.
-
-    This block is the inverse of TextDecoderBlock. It takes text containing
-    special characters (like newlines, tabs, etc.) and converts them into
-    their escape sequence representations (e.g., newline becomes \\n).
-    """
-
-    class Input(BlockSchemaInput):
-        """Input schema for TextEncoderBlock."""
-
-        text: str = SchemaField(
-            description="A string containing special characters to be encoded",
-            placeholder="Your text with newlines and quotes to encode",
-        )
-
-    class Output(BlockSchemaOutput):
-        """Output schema for TextEncoderBlock."""
-
-        encoded_text: str = SchemaField(
-            description="The encoded text with special characters converted to escape sequences"
-        )
-        error: str = SchemaField(description="Error message if encoding fails")
-
-    def __init__(self):
-        super().__init__(
-            id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
-            description="Encodes a string by converting special characters into escape sequences",
-            categories={BlockCategory.TEXT},
-            input_schema=TextEncoderBlock.Input,
-            output_schema=TextEncoderBlock.Output,
-            test_input={
-                "text": """Hello
-World!
-This is a "quoted" string."""
-            },
-            test_output=[
-                (
-                    "encoded_text",
-                    """Hello\\nWorld!\\nThis is a "quoted" string.""",
-                )
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        """
-        Encode the input text by converting special characters to escape sequences.
-
-        Args:
-            input_data: The input containing the text to encode.
-            **kwargs: Additional keyword arguments (unused).
-
-        Yields:
-            The encoded text with escape sequences, or an error message if encoding fails.
-        """
-        try:
-            encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
-                "utf-8"
-            )
-            yield "encoded_text", encoded_text
-        except Exception as e:
-            yield "error", f"Encoding error: {str(e)}"

@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
-            webset = await aexa.websets.get(id=input_data.external_id)
+            webset = aexa.websets.get(id=input_data.external_id)
            webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

            yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
                count=input_data.search_count,
            )

-            webset = await aexa.websets.create(
+            webset = aexa.websets.create(
                params=CreateWebsetParameters(
                    search=search_params,
                    external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

-        sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
+        sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)

        status_str = (
            sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        response = await aexa.websets.list(
+        response = aexa.websets.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_webset = await aexa.websets.get(id=input_data.webset_id)
+        sdk_webset = aexa.websets.get(id=input_data.webset_id)

        status_str = (
            sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
+        deleted_webset = aexa.websets.delete(id=input_data.webset_id)

        status_str = (
            deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
+        canceled_webset = aexa.websets.cancel(id=input_data.webset_id)

        status_str = (
            canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
            entity["description"] = input_data.entity_description
            payload["entity"] = entity

-        sdk_preview = await aexa.websets.preview(params=payload)
+        sdk_preview = aexa.websets.preview(params=payload)

        preview = PreviewWebsetModel.from_sdk(sdk_preview)

@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        webset = await aexa.websets.get(id=input_data.webset_id)
+        webset = aexa.websets.get(id=input_data.webset_id)

        status = (
            webset.status.value
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        webset = await aexa.websets.get(id=input_data.webset_id)
+        webset = aexa.websets.get(id=input_data.webset_id)

        # Extract basic info
        webset_id = webset.id
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
        total_items = 0

        if input_data.include_sample_items and input_data.sample_size > 0:
-            items_response = await aexa.websets.items.list(
+            items_response = aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            sample_items_data = [
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get webset details
-        webset = await aexa.websets.get(id=input_data.webset_id)
+        webset = aexa.websets.get(id=input_data.webset_id)

        status = (
            webset.status.value

@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_enrichment = await aexa.websets.enrichments.create(
+        sdk_enrichment = aexa.websets.enrichments.create(
            webset_id=input_data.webset_id, params=payload
        )

@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
            items_enriched = 0

            while time.time() - poll_start < input_data.polling_timeout:
-                current_enrich = await aexa.websets.enrichments.get(
+                current_enrich = aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=enrichment_id
                )
                current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):

                if current_status in ["completed", "failed", "cancelled"]:
                    # Estimate items from webset searches
-                    webset = await aexa.websets.get(id=input_data.webset_id)
+                    webset = aexa.websets.get(id=input_data.webset_id)
                    if webset.searches:
                        for search in webset.searches:
                            if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_enrichment = await aexa.websets.enrichments.get(
+        sdk_enrichment = aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_enrichment = await aexa.websets.enrichments.delete(
+        deleted_enrichment = aexa.websets.enrichments.delete(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        canceled_enrichment = await aexa.websets.enrichments.cancel(
+        canceled_enrichment = aexa.websets.enrichments.cancel(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

        # Try to estimate how many items were enriched before cancellation
        items_enriched = 0
-        items_response = await aexa.websets.items.list(
+        items_response = aexa.websets.items.list(
            webset_id=input_data.webset_id, limit=100
        )

@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
-        from unittest.mock import AsyncMock, MagicMock
+        from unittest.mock import MagicMock

        # Create mock SDK import object
        mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
-                    imports=MagicMock(create=AsyncMock(return_value=mock_import))
+                    imports=MagicMock(create=lambda *args, **kwargs: mock_import)
                )
            )
        }
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

-        sdk_import = await aexa.websets.imports.create(
+        sdk_import = aexa.websets.imports.create(
            params=payload, csv_data=input_data.csv_data
        )

@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
+        sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)

        import_obj = ImportModel.from_sdk(sdk_import)

@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        response = await aexa.websets.imports.list(
+        response = aexa.websets.imports.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_import = await aexa.websets.imports.delete(
-            import_id=input_data.import_id
-        )
+        deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)

        yield "import_id", deleted_import.id
        yield "success", "true"
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
            }
        )

-        # Create async iterator for list_all
-        async def async_item_iterator(*args, **kwargs):
-            for item in [mock_item1, mock_item2]:
-                yield item
+        # Create mock iterator
+        mock_items = [mock_item1, mock_item2]

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
-                websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
+                websets=MagicMock(
+                    items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
+                )
            )
        }

@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

-        async for sdk_item in item_iterator:
+        for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break

@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_item = await aexa.websets.items.get(
+        sdk_item = aexa.websets.items.get(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
        response = None

        while time.time() - start_time < input_data.wait_timeout:
-            response = await aexa.websets.items.list(
+            response = aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
                interval = min(interval * 1.2, 10)

        if not response:
-            response = await aexa.websets.items.list(
+            response = aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
            )
        else:
-            response = await aexa.websets.items.list(
+            response = aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_item = await aexa.websets.items.delete(
+        deleted_item = aexa.websets.items.delete(
            webset_id=input_data.webset_id, id=input_data.item_id
        )

@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

-        async for sdk_item in item_iterator:
+        for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break

@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        webset = await aexa.websets.get(id=input_data.webset_id)
+        webset = aexa.websets.get(id=input_data.webset_id)

        entity_type = "unknown"
        if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        # Get sample items if requested
        sample_items: List[WebsetItemModel] = []
        if input_data.sample_size > 0:
-            items_response = await aexa.websets.items.list(
+            items_response = aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            # Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get items starting from cursor
-        response = await aexa.websets.items.list(
+        response = aexa.websets.items.list(
            webset_id=input_data.webset_id,
            cursor=input_data.since_cursor,
            limit=input_data.max_items,

@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
-        from unittest.mock import AsyncMock, MagicMock
+        from unittest.mock import MagicMock

        # Create mock SDK monitor object
        mock_monitor = MagicMock()
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
-                    monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
+                    monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
                )
            )
        }
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

-        sdk_monitor = await aexa.websets.monitors.create(params=payload)
+        sdk_monitor = aexa.websets.monitors.create(params=payload)

        monitor = MonitorModel.from_sdk(sdk_monitor)

@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
+        sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)

        monitor = MonitorModel.from_sdk(sdk_monitor)

@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

-        sdk_monitor = await aexa.websets.monitors.update(
+        sdk_monitor = aexa.websets.monitors.update(
            monitor_id=input_data.monitor_id, params=payload
        )

@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        deleted_monitor = await aexa.websets.monitors.delete(
-            monitor_id=input_data.monitor_id
-        )
+        deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)

        yield "monitor_id", deleted_monitor.id
        yield "success", "true"
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-        response = await aexa.websets.monitors.list(
+        response = aexa.websets.monitors.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
            webset_id=input_data.webset_id,

@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
            WebsetTargetStatus.IDLE,
            WebsetTargetStatus.ANY_COMPLETE,
        ]:
-            final_webset = await aexa.websets.wait_until_idle(
+            final_webset = aexa.websets.wait_until_idle(
                id=input_data.webset_id,
                timeout=input_data.timeout,
                poll_interval=input_data.check_interval,
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
        interval = input_data.check_interval
        while time.time() - start_time < input_data.timeout:
            # Get current webset status
-            webset = await aexa.websets.get(id=input_data.webset_id)
+            webset = aexa.websets.get(id=input_data.webset_id)
            current_status = (
                webset.status.value
                if hasattr(webset.status, "value")
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):

        # Timeout reached
        elapsed = time.time() - start_time
-        webset = await aexa.websets.get(id=input_data.webset_id)
+        webset = aexa.websets.get(id=input_data.webset_id)
        final_status = (
            webset.status.value
            if hasattr(webset.status, "value")
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
        try:
            while time.time() - start_time < input_data.timeout:
                # Get current search status using SDK
-                search = await aexa.websets.searches.get(
+                search = aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=input_data.search_id
                )

@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
            elapsed = time.time() - start_time

            # Get last known status
-            search = await aexa.websets.searches.get(
+            search = aexa.websets.searches.get(
                webset_id=input_data.webset_id, id=input_data.search_id
            )
            final_status = (
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
        try:
            while time.time() - start_time < input_data.timeout:
                # Get current enrichment status using SDK
-                enrichment = await aexa.websets.enrichments.get(
+                enrichment = aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=input_data.enrichment_id
                )

@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
            elapsed = time.time() - start_time

            # Get last known status
-            enrichment = await aexa.websets.enrichments.get(
+            enrichment = aexa.websets.enrichments.get(
                webset_id=input_data.webset_id, id=input_data.enrichment_id
            )
            final_status = (
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
    ) -> tuple[list[SampleEnrichmentModel], int]:
        """Get sample enriched data and count."""
        # Get a few items to see enrichment results using SDK
-        response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
+        response = aexa.websets.items.list(webset_id=webset_id, limit=5)

        sample_data: list[SampleEnrichmentModel] = []
        enriched_count = 0

@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):

aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-sdk_search = await aexa.websets.searches.create(
+sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)

@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
poll_start = time.time()

while time.time() - poll_start < input_data.polling_timeout:
-current_search = await aexa.websets.searches.get(
+current_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=search_id
)
current_status = (

@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-sdk_search = await aexa.websets.searches.get(
+sdk_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)

@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

-canceled_search = await aexa.websets.searches.cancel(
+canceled_search = aexa.websets.searches.cancel(
webset_id=input_data.webset_id, id=input_data.search_id
)

@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

# Get webset to check existing searches
-webset = await aexa.websets.get(id=input_data.webset_id)
+webset = aexa.websets.get(id=input_data.webset_id)

# Look for existing search with same query
existing_search = None

@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
if input_data.entity_type != SearchEntityType.AUTO:
payload["entity"] = {"type": input_data.entity_type.value}

-sdk_search = await aexa.websets.searches.create(
+sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)

@@ -162,16 +162,8 @@ class LinearClient:
"searchTerm": team_name,
}

-result = await self.query(query, variables)
-nodes = result["teams"]["nodes"]
-
-if not nodes:
-raise LinearAPIException(
-f"Team '{team_name}' not found. Check the team name or key and try again.",
-status_code=404,
-)
-
-return nodes[0]["id"]
+team_id = await self.query(query, variables)
+return team_id["teams"]["nodes"][0]["id"]
except LinearAPIException as e:
raise e

@@ -248,44 +240,17 @@ class LinearClient:
except LinearAPIException as e:
raise e

-async def try_search_issues(
-self,
-term: str,
-max_results: int = 10,
-team_id: str | None = None,
-) -> list[Issue]:
+async def try_search_issues(self, term: str) -> list[Issue]:
try:
query = """
-query SearchIssues(
-$term: String!,
-$first: Int,
-$teamId: String
-) {
-searchIssues(
-term: $term,
-first: $first,
-teamId: $teamId
-) {
+query SearchIssues($term: String!, $includeComments: Boolean!) {
+searchIssues(term: $term, includeComments: $includeComments) {
nodes {
id
identifier
title
description
priority
-createdAt
-state {
-id
-name
-type
-}
-project {
-id
-name
-}
-assignee {
-id
-name
-}
}
}
}

@@ -293,8 +258,7 @@ class LinearClient:

variables: dict[str, Any] = {
"term": term,
-"first": max_results,
-"teamId": team_id,
+"includeComments": True,
}

issues = await self.query(query, variables)

@@ -17,7 +17,7 @@ from ._config import (
LinearScope,
linear,
)
-from .models import CreateIssueResponse, Issue, State
+from .models import CreateIssueResponse, Issue


class LinearCreateIssueBlock(Block):

@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
-max_results: int = SchemaField(
-description="Maximum number of results to return",
-default=10,
-ge=1,
-le=100,
-)
-team_name: str | None = SchemaField(
-description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
-default=None,
-)

class Output(BlockSchemaOutput):
issues: list[Issue] = SchemaField(description="List of issues")
-error: str = SchemaField(description="Error message if the search failed")

def __init__(self):
super().__init__(

@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
-categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"term": "Test issue",
-"max_results": 10,
-"team_name": None,
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,

@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
[
Issue(
id="abc123",
-identifier="TST-123",
+identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
-state=State(
-id="state1", name="In Progress", type="started"
-),
-createdAt="2026-01-15T10:00:00.000Z",
)
],
)

@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
"search_issues": lambda *args, **kwargs: [
Issue(
id="abc123",
-identifier="TST-123",
+identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
-state=State(id="state1", name="In Progress", type="started"),
-createdAt="2026-01-15T10:00:00.000Z",
)
]
},

@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
-max_results: int = 10,
-team_name: str | None = None,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
-# Resolve team name to ID if provided
-# Raises LinearAPIException with descriptive message if team not found
-team_id: str | None = None
-if team_name:
-team_id = await client.try_get_team_by_name(team_name=team_name)
-
-return await client.try_search_issues(
-term=term,
-max_results=max_results,
-team_id=team_id,
-)
+response: list[Issue] = await client.try_search_issues(term=term)
+return response

async def run(
self,

@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
"""Execute the issue search"""
try:
issues = await self.search_issues(
-credentials=credentials,
-term=input_data.term,
-max_results=input_data.max_results,
-team_name=input_data.team_name,
+credentials=credentials, term=input_data.term
)
yield "issues", issues
except LinearAPIException as e:

@@ -36,21 +36,12 @@ class Project(BaseModel):
content: str | None = None


-class State(BaseModel):
-id: str
-name: str
-type: str | None = (
-None  # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
-)
-
-
class Issue(BaseModel):
id: str
identifier: str
title: str
description: str | None
priority: int
-state: State | None = None
project: Project | None = None
createdAt: str | None = None
comments: list[Comment] | None = None

@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
-CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"

@@ -271,9 +270,6 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
), # claude-4-sonnet-20250514
-LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
-"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
-), # claude-opus-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101

@@ -596,10 +592,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:

def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
-) -> bool | openai.Omit:
+):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
-return openai.omit
+return openai.NOT_GIVEN
return parallel_tool_calls


@@ -1,241 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) Tool Block.
|
|
||||||
|
|
||||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
|
||||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
|
||||||
dropdown and the input/output schema adapts dynamically.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockInput,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
BlockType,
|
|
||||||
)
|
|
||||||
from backend.data.model import (
|
|
||||||
CredentialsField,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
OAuth2Credentials,
|
|
||||||
SchemaField,
|
|
||||||
)
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.json import validate_with_jsonschema
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = OAuth2Credentials(
|
|
||||||
id="test-mcp-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("mock-mcp-token"),
|
|
||||||
refresh_token=SecretStr("mock-refresh"),
|
|
||||||
scopes=[],
|
|
||||||
title="Mock MCP credential",
|
|
||||||
)
|
|
||||||
TEST_CREDENTIALS_INPUT = {
|
|
||||||
"provider": TEST_CREDENTIALS.provider,
|
|
||||||
"id": TEST_CREDENTIALS.id,
|
|
||||||
"type": TEST_CREDENTIALS.type,
|
|
||||||
"title": TEST_CREDENTIALS.title,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolBlock(Block):
|
|
||||||
"""
|
|
||||||
A block that connects to an MCP server, lets the user pick a tool,
|
|
||||||
and executes it with dynamic input/output schema.
|
|
||||||
|
|
||||||
The flow:
|
|
||||||
1. User provides an MCP server URL (and optional credentials)
|
|
||||||
2. Frontend calls the backend to get tool list from that URL
|
|
||||||
3. User selects a tool from a dropdown (available_tools)
|
|
||||||
4. The block's input schema updates to reflect the selected tool's parameters
|
|
||||||
5. On execution, the block calls the MCP server to run the tool
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
server_url: str = SchemaField(
|
|
||||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
|
||||||
placeholder="https://mcp.example.com/mcp",
|
|
||||||
)
|
|
||||||
credentials: MCPCredentials = CredentialsField(
|
|
||||||
discriminator="server_url",
|
|
||||||
description="MCP server OAuth credentials",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
selected_tool: str = SchemaField(
|
|
||||||
description="The MCP tool to execute",
|
|
||||||
placeholder="Select a tool",
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
tool_input_schema: dict[str, Any] = SchemaField(
|
|
||||||
description="JSON Schema for the selected tool's input parameters. "
|
|
||||||
"Populated automatically when a tool is selected.",
|
|
||||||
default={},
|
|
||||||
hidden=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_arguments: dict[str, Any] = SchemaField(
|
|
||||||
description="Arguments to pass to the selected MCP tool. "
|
|
||||||
"The fields here are defined by the tool's input schema.",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
|
||||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
|
||||||
return data.get("tool_input_schema", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
|
||||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
|
||||||
return data.get("tool_arguments", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
|
||||||
"""Check which required tool arguments are missing."""
|
|
||||||
required_fields = cls.get_input_schema(data).get("required", [])
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return set(required_fields) - set(tool_arguments)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
|
||||||
"""Validate tool_arguments against the tool's input schema."""
|
|
||||||
tool_schema = cls.get_input_schema(data)
|
|
||||||
if not tool_schema:
|
|
||||||
return None
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
|
||||||
error: str = SchemaField(description="Error message if the tool call failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
|
||||||
description="Connect to any MCP server and execute its tools. "
|
|
||||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=MCPToolBlock.Input,
|
|
||||||
output_schema=MCPToolBlock.Output,
|
|
||||||
block_type=BlockType.STANDARD,
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_input={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
"selected_tool": "get_weather",
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"result",
|
|
||||||
{"weather": "sunny", "temperature": 20},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_call_mcp_tool": lambda *a, **kw: {
|
|
||||||
"weather": "sunny",
|
|
||||||
"temperature": 20,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _call_mcp_tool(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
tool_name: str,
|
|
||||||
arguments: dict[str, Any],
|
|
||||||
auth_token: str | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
|
||||||
client = MCPClient(server_url, auth_token=auth_token)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
if result.is_error:
|
|
||||||
error_text = ""
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
error_text += item.get("text", "")
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP tool '{tool_name}' returned an error: "
|
|
||||||
f"{error_text or 'Unknown error'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text content from the result
|
|
||||||
output_parts = []
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
# Try to parse as JSON for structured output
|
|
||||||
try:
|
|
||||||
output_parts.append(json.loads(text))
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
output_parts.append(text)
|
|
||||||
elif item.get("type") == "image":
|
|
||||||
output_parts.append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": item.get("data"),
|
|
||||||
"mimeType": item.get("mimeType"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif item.get("type") == "resource":
|
|
||||||
output_parts.append(item.get("resource", {}))
|
|
||||||
|
|
||||||
# If single result, unwrap
|
|
||||||
if len(output_parts) == 1:
|
|
||||||
return output_parts[0]
|
|
||||||
return output_parts if output_parts else None
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
credentials: OAuth2Credentials | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
if not input_data.server_url:
|
|
||||||
yield "error", "MCP server URL is required"
|
|
||||||
return
|
|
||||||
|
|
||||||
if not input_data.selected_tool:
|
|
||||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
|
||||||
return
|
|
||||||
|
|
||||||
auth_token = (
|
|
||||||
credentials.access_token.get_secret_value() if credentials else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self._call_mcp_tool(
|
|
||||||
server_url=input_data.server_url,
|
|
||||||
tool_name=input_data.selected_tool,
|
|
||||||
arguments=input_data.tool_arguments,
|
|
||||||
auth_token=auth_token,
|
|
||||||
)
|
|
||||||
yield "result", result
|
|
||||||
except MCPClientError as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"MCP tool call failed: {e}")
|
|
||||||
yield "error", f"MCP tool call failed: {str(e)}"
|
|
||||||
@@ -1,318 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) HTTP client.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
|
||||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
|
||||||
|
|
||||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
|
||||||
|
|
||||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPTool:
|
|
||||||
"""Represents an MCP tool discovered from a server."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPCallResult:
|
|
||||||
"""Result from calling an MCP tool."""
|
|
||||||
|
|
||||||
content: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
is_error: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClientError(Exception):
|
|
||||||
"""Raised when an MCP protocol error occurs."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
|
||||||
"""
|
|
||||||
Async HTTP client for the MCP Streamable HTTP transport.
|
|
||||||
|
|
||||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
|
||||||
Supports optional Bearer token authentication.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
auth_token: str | None = None,
|
|
||||||
):
|
|
||||||
self.server_url = server_url.rstrip("/")
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self._request_id = 0
|
|
||||||
self._session_id: str | None = None
|
|
||||||
|
|
||||||
def _next_id(self) -> int:
|
|
||||||
self._request_id += 1
|
|
||||||
return self._request_id
|
|
||||||
|
|
||||||
def _build_headers(self) -> dict[str, str]:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept": "application/json, text/event-stream",
|
|
||||||
}
|
|
||||||
if self.auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
|
||||||
if self._session_id:
|
|
||||||
headers["Mcp-Session-Id"] = self._session_id
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def _build_jsonrpc_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
req: dict[str, Any] = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"method": method,
|
|
||||||
"id": self._next_id(),
|
|
||||||
}
|
|
||||||
if params is not None:
|
|
||||||
req["params"] = params
|
|
||||||
return req
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
|
||||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
|
||||||
|
|
||||||
MCP servers may return responses as SSE with format:
|
|
||||||
event: message
|
|
||||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
|
||||||
|
|
||||||
We extract the last `data:` line that contains a JSON-RPC response
|
|
||||||
(i.e. has an "id" field), which is the reply to our request.
|
|
||||||
"""
|
|
||||||
last_data: dict[str, Any] | None = None
|
|
||||||
for line in text.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if stripped.startswith("data:"):
|
|
||||||
payload = stripped[len("data:") :].strip()
|
|
||||||
if not payload:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
parsed = json.loads(payload)
|
|
||||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
|
||||||
if isinstance(parsed, dict) and "id" in parsed:
|
|
||||||
last_data = parsed
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
continue
|
|
||||||
if last_data is None:
|
|
||||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
|
||||||
return last_data
|
|
||||||
|
|
||||||
async def _send_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> Any:
|
|
||||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
|
||||||
|
|
||||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
|
||||||
as required by the MCP Streamable HTTP transport specification.
|
|
||||||
"""
|
|
||||||
payload = self._build_jsonrpc_request(method, params)
|
|
||||||
headers = self._build_headers()
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=True,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
response = await requests.post(self.server_url, json=payload)
|
|
||||||
|
|
||||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
|
||||||
session_id = response.headers.get("Mcp-Session-Id")
|
|
||||||
if session_id:
|
|
||||||
self._session_id = session_id
|
|
||||||
|
|
||||||
content_type = response.headers.get("content-type", "")
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
body = self._parse_sse_response(response.text())
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
body = response.json()
|
|
||||||
except Exception as e:
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server returned non-JSON response: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
# Handle JSON-RPC error
|
|
||||||
if "error" in body:
|
|
||||||
error = body["error"]
|
|
||||||
if isinstance(error, dict):
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server error [{error.get('code', '?')}]: "
|
|
||||||
f"{error.get('message', 'Unknown error')}"
|
|
||||||
)
|
|
||||||
raise MCPClientError(f"MCP server error: {error}")
|
|
||||||
|
|
||||||
return body.get("result")
|
|
||||||
|
|
||||||
async def _send_notification(self, method: str) -> None:
|
|
||||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
|
||||||
headers = self._build_headers()
|
|
||||||
notification = {"jsonrpc": "2.0", "method": method}
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
await requests.post(self.server_url, json=notification)
|
|
||||||
|
|
||||||
async def discover_auth(self) -> dict[str, Any] | None:
|
|
||||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
|
||||||
|
|
||||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
|
||||||
a dict with:
|
|
||||||
- ``authorization_servers``: list of authorization server URLs
|
|
||||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
|
||||||
- ``scopes_supported``: optional list of supported scopes
|
|
||||||
|
|
||||||
The caller can then fetch the authorization server metadata to get
|
|
||||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(self.server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
|
|
||||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_servers" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def discover_auth_server_metadata(
|
|
||||||
self, auth_server_url: str
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
|
||||||
|
|
||||||
Given an authorization server URL, returns a dict with:
|
|
||||||
- ``authorization_endpoint``
|
|
||||||
- ``token_endpoint``
|
|
||||||
- ``registration_endpoint`` (for dynamic client registration)
|
|
||||||
- ``scopes_supported``
|
|
||||||
- ``code_challenge_methods_supported``
|
|
||||||
- etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(auth_server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
|
|
||||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
|
||||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def initialize(self) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Send the MCP initialize request.
|
|
||||||
|
|
||||||
This is required by the MCP protocol before any other requests.
|
|
||||||
Returns the server's capabilities.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"initialize",
|
|
||||||
{
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Send initialized notification (no response expected)
|
|
||||||
await self._send_notification("notifications/initialized")
|
|
||||||
|
|
||||||
return result or {}
|
|
||||||
|
|
||||||
async def list_tools(self) -> list[MCPTool]:
|
|
||||||
"""
|
|
||||||
Discover available tools from the MCP server.
|
|
||||||
|
|
||||||
Returns a list of MCPTool objects with name, description, and input schema.
|
|
||||||
"""
|
|
||||||
result = await self._send_request("tools/list")
|
|
||||||
if not result or "tools" not in result:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
for tool_data in result["tools"]:
|
|
||||||
tools.append(
|
|
||||||
MCPTool(
|
|
||||||
name=tool_data.get("name", ""),
|
|
||||||
description=tool_data.get("description", ""),
|
|
||||||
input_schema=tool_data.get("inputSchema", {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def call_tool(
|
|
||||||
self, tool_name: str, arguments: dict[str, Any]
|
|
||||||
) -> MCPCallResult:
|
|
||||||
"""
|
|
||||||
Call a tool on the MCP server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: The name of the tool to call.
|
|
||||||
arguments: The arguments to pass to the tool.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MCPCallResult with the tool's response content.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"tools/call",
|
|
||||||
{"name": tool_name, "arguments": arguments},
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
return MCPCallResult(is_error=True)
|
|
||||||
|
|
||||||
return MCPCallResult(
|
|
||||||
content=result.get("content", []),
|
|
||||||
is_error=result.get("isError", False),
|
|
||||||
)
|
|
||||||
@@ -1,42 +0,0 @@
-"""
-Conftest for MCP block tests.
-
-Override the session-scoped server and graph_cleanup fixtures from
-backend/conftest.py so that MCP integration tests don't spin up the
-full SpinTestServer infrastructure.
-"""
-
-import pytest
-
-
-def pytest_configure(config: pytest.Config) -> None:
-config.addinivalue_line("markers", "e2e: end-to-end tests requiring network")
-
-
-def pytest_collection_modifyitems(
-config: pytest.Config, items: list[pytest.Item]
-) -> None:
-"""Skip e2e tests unless --run-e2e is passed."""
-if not config.getoption("--run-e2e", default=False):
-skip_e2e = pytest.mark.skip(reason="need --run-e2e option to run")
-for item in items:
-if "e2e" in item.keywords:
-item.add_marker(skip_e2e)
-
-
-def pytest_addoption(parser: pytest.Parser) -> None:
-parser.addoption(
-"--run-e2e", action="store_true", default=False, help="run e2e tests"
-)
-
-
-@pytest.fixture(scope="session")
-def server():
-"""No-op override — MCP tests don't need the full platform server."""
-yield None
-
-
-@pytest.fixture(scope="session", autouse=True)
-def graph_cleanup(server):
-"""No-op override — MCP tests don't create graphs."""
-yield

@@ -1,198 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
|
||||||
|
|
||||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
|
||||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
|
||||||
This handler accepts those endpoints at construction time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import urllib.parse
|
|
||||||
from typing import ClassVar, Optional
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthHandler(BaseOAuthHandler):
|
|
||||||
"""
|
|
||||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
|
||||||
|
|
||||||
Construction requires the authorization and token endpoint URLs,
|
|
||||||
which are obtained via MCP OAuth metadata discovery
|
|
||||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
|
||||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
*,
|
|
||||||
authorize_url: str,
|
|
||||||
token_url: str,
|
|
||||||
revoke_url: str | None = None,
|
|
||||||
resource_url: str | None = None,
|
|
||||||
):
|
|
||||||
self.client_id = client_id
|
|
||||||
self.client_secret = client_secret
|
|
||||||
self.redirect_uri = redirect_uri
|
|
||||||
self.authorize_url = authorize_url
|
|
||||||
self.token_url = token_url
|
|
||||||
self.revoke_url = revoke_url
|
|
||||||
self.resource_url = resource_url
|
|
||||||
|
|
||||||
def get_login_url(
|
|
||||||
self,
|
|
||||||
scopes: list[str],
|
|
||||||
state: str,
|
|
||||||
code_challenge: Optional[str],
|
|
||||||
) -> str:
|
|
||||||
scopes = self.handle_default_scopes(scopes)
|
|
||||||
|
|
||||||
params: dict[str, str] = {
|
|
||||||
"response_type": "code",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"state": state,
|
|
||||||
}
|
|
||||||
if scopes:
|
|
||||||
params["scope"] = " ".join(scopes)
|
|
||||||
# PKCE (S256) — included when the caller provides a code_challenge
|
|
||||||
if code_challenge:
|
|
||||||
params["code_challenge"] = code_challenge
|
|
||||||
params["code_challenge_method"] = "S256"
|
|
||||||
# MCP spec requires resource indicator (RFC 8707)
|
|
||||||
if self.resource_url:
|
|
||||||
params["resource"] = self.resource_url
|
|
||||||
|
|
||||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
|
||||||
|
|
||||||
async def exchange_code_for_tokens(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
scopes: list[str],
|
|
||||||
code_verifier: Optional[str],
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if code_verifier:
|
|
||||||
data["code_verifier"] = code_verifier
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
provider=str(self.PROVIDER_NAME),
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=scopes,
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": self.token_url,
|
|
||||||
"mcp_resource_url": self.resource_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _refresh_tokens(
|
|
||||||
self, credentials: OAuth2Credentials
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
if not credentials.refresh_token:
|
|
||||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
|
||||||
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=str(self.PROVIDER_NAME),
|
|
||||||
title=credentials.title,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else credentials.refresh_token
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
metadata=credentials.metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
|
||||||
if not self.revoke_url:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = {
|
|
||||||
"token": credentials.access_token.get_secret_value(),
|
|
||||||
"token_type_hint": "access_token",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
await Requests().post(
|
|
||||||
self.revoke_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
|
||||||
return False
|
|
||||||
@@ -1,104 +0,0 @@
-"""
-End-to-end tests against a real public MCP server.
-
-These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
-which is publicly accessible without authentication and returns SSE responses.
-
-Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
-independently of the rest of the test suite (they require network access).
-"""
-
-import json
-
-import pytest
-
-from backend.blocks.mcp.client import MCPClient
-
-# Public MCP server that requires no authentication
-OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
-
-
-@pytest.mark.e2e
-class TestRealMCPServer:
-"""Tests against the live OpenAI docs MCP server."""
-
-@pytest.mark.asyncio
-async def test_initialize(self):
-"""Verify we can complete the MCP handshake with a real server."""
-client = MCPClient(OPENAI_DOCS_MCP_URL)
-result = await client.initialize()
-
-assert result["protocolVersion"] == "2025-03-26"
-assert "serverInfo" in result
-assert result["serverInfo"]["name"] == "openai-docs-mcp"
-assert "tools" in result.get("capabilities", {})
-
-@pytest.mark.asyncio
-async def test_list_tools(self):
-"""Verify we can discover tools from a real MCP server."""
-client = MCPClient(OPENAI_DOCS_MCP_URL)
-await client.initialize()
-tools = await client.list_tools()
-
-assert len(tools) >= 3  # server has at least 5 tools as of writing
-
-tool_names = {t.name for t in tools}
-# These tools are documented and should be stable
-assert "search_openai_docs" in tool_names
-assert "list_openai_docs" in tool_names
-assert "fetch_openai_doc" in tool_names
-
-# Verify schema structure
-search_tool = next(t for t in tools if t.name == "search_openai_docs")
-assert "query" in search_tool.input_schema.get("properties", {})
-assert "query" in search_tool.input_schema.get("required", [])
-
-@pytest.mark.asyncio
-async def test_call_tool_list_api_endpoints(self):
-"""Call the list_api_endpoints tool and verify we get real data."""
-client = MCPClient(OPENAI_DOCS_MCP_URL)
-await client.initialize()
-result = await client.call_tool("list_api_endpoints", {})
-
-assert not result.is_error
-assert len(result.content) >= 1
-assert result.content[0]["type"] == "text"
-
-data = json.loads(result.content[0]["text"])
-assert "paths" in data or "urls" in data
-# The OpenAI API should have many endpoints
-total = data.get("total", len(data.get("paths", [])))
-assert total > 50
-
-@pytest.mark.asyncio
-async def test_call_tool_search(self):
-"""Search for docs and verify we get results."""
-client = MCPClient(OPENAI_DOCS_MCP_URL)
-await client.initialize()
-result = await client.call_tool(
-"search_openai_docs", {"query": "chat completions", "limit": 3}
-)
-
-assert not result.is_error
-assert len(result.content) >= 1
-
-@pytest.mark.asyncio
-async def test_sse_response_handling(self):
-"""Verify the client correctly handles SSE responses from a real server.
-
-This is the key test — our local test server returns JSON,
-but real MCP servers typically return SSE. This proves the
-SSE parsing works end-to-end.
-"""
-client = MCPClient(OPENAI_DOCS_MCP_URL)
-# initialize() internally calls _send_request which must parse SSE
-result = await client.initialize()
-
-# If we got here without error, SSE parsing works
-assert isinstance(result, dict)
-assert "protocolVersion" in result
-
-# Also verify list_tools works (another SSE response)
-tools = await client.list_tools()
-assert len(tools) > 0
-assert all(hasattr(t, "name") for t in tools)

@@ -1,389 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
|
||||||
|
|
||||||
These tests spin up a local MCP test server and run the full client/block flow
|
|
||||||
against it — no mocking, real HTTP requests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from aiohttp import web
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-integration"
|
|
||||||
|
|
||||||
|
|
||||||
class _MCPTestServer:
|
|
||||||
"""
|
|
||||||
Run an MCP test server in a background thread with its own event loop.
|
|
||||||
This avoids event loop conflicts with pytest-asyncio.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, auth_token: str | None = None):
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.url: str = ""
|
|
||||||
self._runner: web.AppRunner | None = None
|
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._started = threading.Event()
|
|
||||||
|
|
||||||
def _run(self):
|
|
||||||
self._loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._loop)
|
|
||||||
self._loop.run_until_complete(self._start())
|
|
||||||
self._started.set()
|
|
||||||
self._loop.run_forever()
|
|
||||||
|
|
||||||
async def _start(self):
|
|
||||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
|
||||||
self._runner = web.AppRunner(app)
|
|
||||||
await self._runner.setup()
|
|
||||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
|
||||||
await site.start()
|
|
||||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
|
||||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
if not self._started.wait(timeout=5):
|
|
||||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
|
||||||
return self
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
if self._loop and self._runner:
|
|
||||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
if self._thread:
|
|
||||||
self._thread.join(timeout=5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server():
|
|
||||||
"""Start a local MCP test server in a background thread."""
|
|
||||||
server = _MCPTestServer()
|
|
||||||
server.start()
|
|
||||||
yield server.url
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server_with_auth():
|
|
||||||
"""Start a local MCP test server with auth in a background thread."""
|
|
||||||
server = _MCPTestServer(auth_token="test-secret-token")
|
|
||||||
server.start()
|
|
||||||
yield server.url, "test-secret-token"
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _allow_localhost():
|
|
||||||
"""
|
|
||||||
Allow 127.0.0.1 through SSRF protection for integration tests.
|
|
||||||
|
|
||||||
The Requests class blocks private IPs by default. We patch the Requests
|
|
||||||
constructor to always include 127.0.0.1 as a trusted origin so the local
|
|
||||||
test server is reachable.
|
|
||||||
"""
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
original_init = Requests.__init__
|
|
||||||
|
|
||||||
def patched_init(self, *args, **kwargs):
|
|
||||||
trusted = list(kwargs.get("trusted_origins") or [])
|
|
||||||
trusted.append("http://127.0.0.1")
|
|
||||||
kwargs["trusted_origins"] = trusted
|
|
||||||
original_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
with patch.object(Requests, "__init__", patched_init):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
|
||||||
"""Create an MCPClient for integration tests."""
|
|
||||||
return MCPClient(url, auth_token=auth_token)
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient integration tests ──────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientIntegration:
|
|
||||||
"""Test MCPClient against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
|
||||||
assert "tools" in result["capabilities"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
|
||||||
|
|
||||||
# Check get_weather schema
|
|
||||||
weather = next(t for t in tools if t.name == "get_weather")
|
|
||||||
assert weather.description == "Get current weather for a city"
|
|
||||||
assert "city" in weather.input_schema["properties"]
|
|
||||||
assert weather.input_schema["required"] == ["city"]
|
|
||||||
|
|
||||||
# Check add_numbers schema
|
|
||||||
add = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
assert "a" in add.input_schema["properties"]
|
|
||||||
assert "b" in add.input_schema["properties"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_get_weather(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["city"] == "London"
|
|
||||||
assert data["temperature"] == 22
|
|
||||||
assert data["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_add_numbers(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["result"] == 10
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_echo(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert result.content[0]["text"] == "Hello MCP!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_unknown_tool(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("nonexistent_tool", {})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
assert "Unknown tool" in result.content[0]["text"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_success(self, mcp_server_with_auth):
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token=token)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_failure(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token="wrong-token")
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auth_missing(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url)
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlockIntegration:
|
|
||||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_get_weather(self, mcp_server):
|
|
||||||
"""Full flow: discover tools, select one, execute it."""
|
|
||||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
# Step 2: User selects "get_weather" and we get its schema
|
|
||||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
|
||||||
|
|
||||||
# Step 3: Execute the block — no credentials (public server)
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema=weather_tool.input_schema,
|
|
||||||
tool_arguments={"city": "Paris"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
result = outputs[0][1]
|
|
||||||
assert result["city"] == "Paris"
|
|
||||||
assert result["temperature"] == 22
|
|
||||||
assert result["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_add_numbers(self, mcp_server):
|
|
||||||
"""Full flow for add_numbers tool."""
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="add_numbers",
|
|
||||||
tool_input_schema=add_tool.input_schema,
|
|
||||||
tool_arguments={"a": 42, "b": 58},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1]["result"] == 100
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
|
||||||
"""Verify plain text (non-JSON) responses work."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
|
||||||
"""Calling an unknown tool should yield an error output."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="nonexistent_tool",
|
|
||||||
tool_arguments={},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "returned an error" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
|
||||||
"""Full flow with authentication via credentials kwarg."""
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=url,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Authenticated!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass credentials via the standard kwarg (as the executor would)
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="test-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr(token),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Authenticated!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_credentials_runs_without_auth(self, mcp_server):
|
|
||||||
"""Block runs without auth when no credentials are provided."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "No auth needed"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=None
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "No auth needed"
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP client and MCPToolBlock.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
|
||||||
from backend.util.test import execute_block_test
|
|
||||||
|
|
||||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestSSEParsing:
|
|
||||||
"""Tests for SSE (text/event-stream) response parsing."""
|
|
||||||
|
|
||||||
def test_parse_sse_simple(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"tools": []}
|
|
||||||
assert body["id"] == 1
|
|
||||||
|
|
||||||
def test_parse_sse_with_notifications(self):
|
|
||||||
"""SSE streams can contain notifications (no id) before the response."""
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
|
||||||
"\n"
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"ok": True}
|
|
||||||
assert body["id"] == 2
|
|
||||||
|
|
||||||
def test_parse_sse_error_response(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert "error" in body
|
|
||||||
assert body["error"]["code"] == -32600
|
|
||||||
|
|
||||||
def test_parse_sse_no_data_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("event: message\n\n")
|
|
||||||
|
|
||||||
def test_parse_sse_empty_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("")
|
|
||||||
|
|
||||||
def test_parse_sse_ignores_non_data_lines(self):
|
|
||||||
sse = (
|
|
||||||
": comment line\n"
|
|
||||||
"event: message\n"
|
|
||||||
"id: 123\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "ok"
|
|
||||||
|
|
||||||
def test_parse_sse_uses_last_response(self):
|
|
||||||
"""If multiple responses exist, use the last one."""
|
|
||||||
sse = (
|
|
||||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "second"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClient:
|
|
||||||
"""Tests for the MCP HTTP client."""
|
|
||||||
|
|
||||||
def test_build_headers_without_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert "Authorization" not in headers
|
|
||||||
assert headers["Content-Type"] == "application/json"
|
|
||||||
|
|
||||||
def test_build_headers_with_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert headers["Authorization"] == "Bearer my-token"
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req["jsonrpc"] == "2.0"
|
|
||||||
assert req["method"] == "tools/list"
|
|
||||||
assert "id" in req
|
|
||||||
assert "params" not in req
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request_with_params(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request(
|
|
||||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
|
||||||
)
|
|
||||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
|
||||||
|
|
||||||
def test_request_id_increments(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req1 = client._build_jsonrpc_request("tools/list")
|
|
||||||
req2 = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req2["id"] > req1["id"]
|
|
||||||
|
|
||||||
def test_server_url_trailing_slash_stripped(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp/")
|
|
||||||
assert client.server_url == "https://mcp.example.com/mcp"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_request_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"result": {"tools": []},
|
|
||||||
"id": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
result = await client._send_request("tools/list")
|
|
||||||
assert result == {"tools": []}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_request_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
async def mock_send(*args, **kwargs):
|
|
||||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
|
||||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
|
||||||
await client._send_request("tools/list")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "search",
|
|
||||||
"description": "Search the web",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 2
|
|
||||||
assert tools[0].name == "get_weather"
|
|
||||||
assert tools[0].description == "Get current weather for a city"
|
|
||||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
|
||||||
assert tools[1].name == "search"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_empty(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
|
||||||
],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [{"type": "text", "text": "City not found"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "???"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {}},
|
|
||||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
|
||||||
patch.object(client, "_send_notification") as mock_notif,
|
|
||||||
):
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
mock_req.assert_called_once()
|
|
||||||
mock_notif.assert_called_once_with("notifications/initialized")
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-123"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlock:
|
|
||||||
"""Tests for the MCPToolBlock."""
|
|
||||||
|
|
||||||
def test_block_instantiation(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
|
||||||
assert block.name == "MCPToolBlock"
|
|
||||||
|
|
||||||
def test_input_schema_has_required_fields(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "server_url" in props
|
|
||||||
assert "selected_tool" in props
|
|
||||||
assert "tool_arguments" in props
|
|
||||||
assert "credentials" in props
|
|
||||||
|
|
||||||
def test_output_schema(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.output_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "result" in props
|
|
||||||
assert "error" in props
|
|
||||||
|
|
||||||
def test_get_input_schema_with_tool_schema(self):
|
|
||||||
tool_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
data = {"tool_input_schema": tool_schema}
|
|
||||||
result = MCPToolBlock.Input.get_input_schema(data)
|
|
||||||
assert result == tool_schema
|
|
||||||
|
|
||||||
def test_get_input_schema_without_tool_schema(self):
|
|
||||||
result = MCPToolBlock.Input.get_input_schema({})
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
def test_get_input_defaults(self):
|
|
||||||
data = {"tool_arguments": {"city": "London"}}
|
|
||||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
|
||||||
assert result == {"city": "London"}
|
|
||||||
|
|
||||||
def test_get_missing_input(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {"type": "string"},
|
|
||||||
"units": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["city", "units"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == {"units"}
|
|
||||||
|
|
||||||
def test_get_missing_input_all_present(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == set()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_with_mock(self):
|
|
||||||
"""Test the block using the built-in test infrastructure."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
await execute_block_test(block)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_missing_server_url(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="",
|
|
||||||
selected_tool="test",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [("error", "MCP server URL is required")]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_missing_tool(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [
|
|
||||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_success(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
},
|
|
||||||
tool_arguments={"city": "London"},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
return {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_mcp_error(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="bad_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
raise MCPClientError("Tool not found")
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "Tool not found" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_parses_json_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": '{"temp": 20}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"temp": 20}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_plain_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Hello, world!"},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == "Hello, world!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_multiple_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Part 1"},
|
|
||||||
{"type": "text", "text": '{"part": 2}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == ["Part 1", {"part": 2}]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_error_result(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[{"type": "text", "text": "Something went wrong"}],
|
|
||||||
is_error=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
with pytest.raises(MCPClientError, match="returned an error"):
|
|
||||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_mcp_tool_image_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_with_credentials(self):
|
|
||||||
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="cred-123",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("resolved-token"),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
async for _ in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert captured_tokens == ["resolved-token"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_without_credentials(self):
|
|
||||||
"""Verify the block works without credentials (public server)."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert captured_tokens == [None]
|
|
||||||
assert outputs == [("result", "ok")]
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP OAuth handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
|
||||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
|
||||||
resp = MagicMock()
|
|
||||||
resp.status = status
|
|
||||||
resp.ok = 200 <= status < 300
|
|
||||||
resp.json.return_value = json_data
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPOAuthHandler:
|
|
||||||
"""Tests for the MCPOAuthHandler."""
|
|
||||||
|
|
||||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
|
||||||
defaults = {
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-client-secret",
|
|
||||||
"redirect_uri": "https://app.example.com/callback",
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
}
|
|
||||||
defaults.update(overrides)
|
|
||||||
return MCPOAuthHandler(**defaults)
|
|
||||||
|
|
||||||
def test_get_login_url_basic(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=["read", "write"],
|
|
||||||
state="random-state-token",
|
|
||||||
code_challenge="S256-challenge-value",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "https://auth.example.com/authorize?" in url
|
|
||||||
assert "response_type=code" in url
|
|
||||||
assert "client_id=test-client-id" in url
|
|
||||||
assert "state=random-state-token" in url
|
|
||||||
assert "code_challenge=S256-challenge-value" in url
|
|
||||||
assert "code_challenge_method=S256" in url
|
|
||||||
assert "scope=read+write" in url
|
|
||||||
|
|
||||||
def test_get_login_url_with_resource(self):
|
|
||||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=[], state="state", code_challenge="challenge"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "resource=https" in url
|
|
||||||
|
|
||||||
def test_get_login_url_without_pkce(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
|
||||||
|
|
||||||
assert "code_challenge" not in url
|
|
||||||
assert "code_challenge_method" not in url
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exchange_code_for_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "new-access-token",
|
|
||||||
"refresh_token": "new-refresh-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
creds = await handler.exchange_code_for_tokens(
|
|
||||||
code="auth-code",
|
|
||||||
scopes=["read"],
|
|
||||||
code_verifier="pkce-verifier",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(creds, OAuth2Credentials)
|
|
||||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
|
||||||
assert creds.refresh_token is not None
|
|
||||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
|
||||||
assert creds.scopes == ["read"]
|
|
||||||
assert creds.access_token_expires_at is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_refresh_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
existing_creds = OAuth2Credentials(
|
|
||||||
id="existing-id",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("old-token"),
|
|
||||||
refresh_token=SecretStr("old-refresh"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "refreshed-token",
|
|
||||||
"refresh_token": "new-refresh",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
refreshed = await handler._refresh_tokens(existing_creds)
|
|
||||||
|
|
||||||
assert refreshed.id == "existing-id"
|
|
||||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
|
||||||
assert refreshed.refresh_token is not None
|
|
||||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_refresh_tokens_no_refresh_token(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No refresh token"):
|
|
||||||
await handler._refresh_tokens(creds)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_revoke_tokens_no_url(self):
|
|
||||||
handler = self._make_handler(revoke_url=None)
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_revoke_tokens_with_url(self):
|
|
||||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientDiscovery:
|
|
||||||
"""Tests for MCPClient OAuth metadata discovery."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_discover_auth_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_discover_auth_not_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=404)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_discover_auth_server_metadata(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
server_metadata = {
|
|
||||||
"issuer": "https://auth.example.com",
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
"registration_endpoint": "https://auth.example.com/register",
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(server_metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth_server_metadata(
|
|
||||||
"https://auth.example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
|
||||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
|
||||||
@@ -1,162 +0,0 @@
"""
Minimal MCP server for integration testing.

Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""

import json
import logging

from aiohttp import web

logger = logging.getLogger(__name__)

# Sample tools this test server exposes
TEST_TOOLS = [
    {
        "name": "get_weather",
        "description": "Get current weather for a city",
        "inputSchema": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name",
                },
            },
            "required": ["city"],
        },
    },
    {
        "name": "add_numbers",
        "description": "Add two numbers together",
        "inputSchema": {
            "type": "object",
            "properties": {
                "a": {"type": "number", "description": "First number"},
                "b": {"type": "number", "description": "Second number"},
            },
            "required": ["a", "b"],
        },
    },
    {
        "name": "echo",
        "description": "Echo back the input message",
        "inputSchema": {
            "type": "object",
            "properties": {
                "message": {"type": "string", "description": "Message to echo"},
            },
            "required": ["message"],
        },
    },
]


def _handle_initialize(params: dict) -> dict:
    return {
        "protocolVersion": "2025-03-26",
        "capabilities": {"tools": {"listChanged": False}},
        "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
    }


def _handle_tools_list(params: dict) -> dict:
    return {"tools": TEST_TOOLS}


def _handle_tools_call(params: dict) -> dict:
    tool_name = params.get("name", "")
    arguments = params.get("arguments", {})

    if tool_name == "get_weather":
        city = arguments.get("city", "Unknown")
        return {
            "content": [
                {
                    "type": "text",
                    "text": json.dumps(
                        {"city": city, "temperature": 22, "condition": "sunny"}
                    ),
                }
            ],
        }

    elif tool_name == "add_numbers":
        a = arguments.get("a", 0)
        b = arguments.get("b", 0)
        return {
            "content": [{"type": "text", "text": json.dumps({"result": a + b})}],
        }

    elif tool_name == "echo":
        message = arguments.get("message", "")
        return {
            "content": [{"type": "text", "text": message}],
        }

    else:
        return {
            "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
            "isError": True,
        }


HANDLERS = {
    "initialize": _handle_initialize,
    "tools/list": _handle_tools_list,
    "tools/call": _handle_tools_call,
}


async def handle_mcp_request(request: web.Request) -> web.Response:
    """Handle incoming MCP JSON-RPC 2.0 requests."""
    # Check auth if configured
    expected_token = request.app.get("auth_token")
    if expected_token:
        auth_header = request.headers.get("Authorization", "")
        if auth_header != f"Bearer {expected_token}":
            return web.json_response(
                {
                    "jsonrpc": "2.0",
                    "error": {"code": -32001, "message": "Unauthorized"},
                    "id": None,
                },
                status=401,
            )

    body = await request.json()

    # Handle notifications (no id field) — just acknowledge
    if "id" not in body:
        return web.Response(status=202)

    method = body.get("method", "")
    params = body.get("params", {})
    request_id = body.get("id")

    handler = HANDLERS.get(method)
    if not handler:
        return web.json_response(
            {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32601,
                    "message": f"Method not found: {method}",
                },
                "id": request_id,
            }
        )

    result = handler(params)
    return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})


def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
    """Create an aiohttp app that acts as an MCP server."""
    app = web.Application()
    app.router.add_post("/mcp", handle_mcp_request)
    if auth_token:
        app["auth_token"] = auth_token
    return app
246
autogpt_platform/backend/backend/blocks/media.py
Normal file
@@ -0,0 +1,246 @@
import os
import tempfile
from typing import Optional

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):

    class Input(BlockSchemaInput):
        media_in: MediaFileType = SchemaField(
            description="Media input (URL, data URI, or local path)."
        )
        is_video: bool = SchemaField(
            description="Whether the media is a video (True) or audio (False).",
            default=True,
        )

    class Output(BlockSchemaOutput):
        duration: float = SchemaField(
            description="Duration of the media file (in seconds)."
        )

    def __init__(self):
        super().__init__(
            id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
            description="Block to get the duration of a media file.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=MediaDurationBlock.Input,
            output_schema=MediaDurationBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
        )

        # 2) Load the clip
        if input_data.is_video:
            clip = VideoFileClip(media_abspath)
        else:
            clip = AudioFileClip(media_abspath)

        yield "duration", clip.duration


class LoopVideoBlock(Block):
    """
    Block for looping (repeating) a video clip until a given duration or number of loops.
    """

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="The input video (can be a URL, data URI, or local path)."
        )
        # Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
        duration: Optional[float] = SchemaField(
            description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
            default=None,
            ge=0.0,
        )
        n_loops: Optional[int] = SchemaField(
            description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
            default=None,
            ge=1,
        )

    class Output(BlockSchemaOutput):
        video_out: str = SchemaField(
            description="Looped video returned either as a relative path or a data URI."
        )

    def __init__(self):
        super().__init__(
            id="8bf9eef6-5451-4213-b265-25306446e94b",
            description="Block to loop a video to a given duration or number of repeats.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=LoopVideoBlock.Input,
            output_schema=LoopVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the input video locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

        # 2) Load the clip
        clip = VideoFileClip(input_abspath)

        # 3) Apply the loop effect
        looped_clip = clip
        if input_data.duration:
            # Loop until we reach the specified duration
            looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
        elif input_data.n_loops:
            looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
        else:
            raise ValueError("Either 'duration' or 'n_loops' must be provided.")

        assert isinstance(looped_clip, VideoFileClip)

        # 4) Save the looped output
        output_filename = MediaFileType(
            f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
        )
        output_abspath = get_exec_file_path(graph_exec_id, output_filename)

        looped_clip = looped_clip.with_audio(clip.audio)
        looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out


class AddAudioToVideoBlock(Block):
    """
    Block that adds (attaches) an audio track to an existing video.
    Optionally scale the volume of the new track.
    """

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Video input (URL, data URI, or local path)."
        )
        audio_in: MediaFileType = SchemaField(
            description="Audio input (URL, data URI, or local path)."
        )
        volume: float = SchemaField(
            description="Volume scale for the newly attached audio track (1.0 = original).",
            default=1.0,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Final video (with attached audio), as a path or data URI."
        )

    def __init__(self):
        super().__init__(
            id="3503748d-62b6-4425-91d6-725b064af509",
            description="Block to attach an audio file to a video file using moviepy.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=AddAudioToVideoBlock.Input,
            output_schema=AddAudioToVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the inputs locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        local_audio_path = await store_media_file(
            file=input_data.audio_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

        abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
        video_abspath = os.path.join(abs_temp_dir, local_video_path)
        audio_abspath = os.path.join(abs_temp_dir, local_audio_path)

        # 2) Load video + audio with moviepy
        video_clip = VideoFileClip(video_abspath)
        audio_clip = AudioFileClip(audio_abspath)
        # Optionally scale volume
        if input_data.volume != 1.0:
            audio_clip = audio_clip.with_volume_scaled(input_data.volume)

        # 3) Attach the new audio track
        final_clip = video_clip.with_audio(audio_clip)

        # 4) Write to output file
        output_filename = MediaFileType(
            f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
        )
        output_abspath = os.path.join(abs_temp_dir, output_filename)
        final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # 5) Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
@@ -182,7 +182,10 @@ class StagehandObserveBlock(Block):
        **kwargs,
    ) -> BlockOutput:

        logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
        logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
        logger.info(
            f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
        )

        with disable_signal_handling():
            stagehand = Stagehand(
@@ -279,7 +282,10 @@ class StagehandActBlock(Block):
        **kwargs,
    ) -> BlockOutput:

        logger.debug(f"ACT: Using model provider {model_credentials.provider}")
        logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
        logger.info(
            f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
        )

        with disable_signal_handling():
            stagehand = Stagehand(
@@ -364,7 +370,10 @@ class StagehandExtractBlock(Block):
        **kwargs,
    ) -> BlockOutput:

        logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
        logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
        logger.info(
            f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
        )

        with disable_signal_handling():
            stagehand = Stagehand(
@@ -1,77 +0,0 @@
import pytest

from backend.blocks.encoder_block import TextEncoderBlock


@pytest.mark.asyncio
async def test_text_encoder_basic():
    """Test basic encoding of newlines and special characters."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert result[0][1] == "Hello\\nWorld"


@pytest.mark.asyncio
async def test_text_encoder_multiple_escapes():
    """Test encoding of multiple escape sequences."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(
        TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
    ):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert "\\n" in result[0][1]
    assert "\\t" in result[0][1]
    assert "\\r" in result[0][1]


@pytest.mark.asyncio
async def test_text_encoder_unicode():
    """Test that unicode characters are handled correctly."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    # Unicode characters should be escaped as \uXXXX sequences
    assert "\\n" in result[0][1]


@pytest.mark.asyncio
async def test_text_encoder_empty_string():
    """Test encoding of an empty string."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert result[0][1] == ""


@pytest.mark.asyncio
async def test_text_encoder_error_handling():
    """Test that encoding errors are handled gracefully."""
    from unittest.mock import patch

    block = TextEncoderBlock()
    result = []

    with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
        async for output in block.run(TextEncoderBlock.Input(text="test")):
            result.append(output)

    assert len(result) == 1
    assert result[0][0] == "error"
    assert "Mocked encoding error" in result[0][1]
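Each test above repeats the same collect-the-outputs loop around block.run(); a small hypothetical helper of this shape (not part of the deleted file) captures that pattern:

# Hypothetical helper illustrating the output-collection pattern used by the tests above.
async def collect_outputs(block, input_data) -> list[tuple[str, object]]:
    result = []
    async for output in block.run(input_data):  # block.run is an async generator
        result.append(output)                   # each item is a (name, value) tuple
    return result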
@@ -1,37 +0,0 @@
"""Video editing blocks for AutoGPT Platform.

This module provides blocks for:
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
- Clipping/trimming video segments
- Concatenating multiple videos
- Adding text overlays
- Adding AI-generated narration
- Getting media duration
- Looping videos
- Adding audio to videos

Dependencies:
- yt-dlp: For video downloading
- moviepy: For video editing operations
- elevenlabs: For AI narration (optional)
"""

from backend.blocks.video.add_audio import AddAudioToVideoBlock
from backend.blocks.video.clip import VideoClipBlock
from backend.blocks.video.concat import VideoConcatBlock
from backend.blocks.video.download import VideoDownloadBlock
from backend.blocks.video.duration import MediaDurationBlock
from backend.blocks.video.loop import LoopVideoBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.blocks.video.text_overlay import VideoTextOverlayBlock

__all__ = [
    "AddAudioToVideoBlock",
    "LoopVideoBlock",
    "MediaDurationBlock",
    "VideoClipBlock",
    "VideoConcatBlock",
    "VideoDownloadBlock",
    "VideoNarrationBlock",
    "VideoTextOverlayBlock",
]
@@ -1,131 +0,0 @@
"""Shared utilities for video blocks."""

from __future__ import annotations

import logging
import os
import re
import subprocess
from pathlib import Path

logger = logging.getLogger(__name__)

# Known operation tags added by video blocks
_VIDEO_OPS = (
    r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
)

# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
_BLOCK_PREFIX_RE = re.compile(
    r"^[a-zA-Z0-9_-]*"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"[a-zA-Z0-9_-]*"
    r"_" + _VIDEO_OPS + r"_"
)

# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
_UUID_PREFIX_RE = re.compile(
    r"^[a-zA-Z0-9_-]*"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"[a-zA-Z0-9_-]*_"
)


def extract_source_name(input_path: str, max_length: int = 50) -> str:
    """Extract the original source filename by stripping block-generated prefixes.

    Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
    when chaining video blocks, recovering the original human-readable name.

    Safe for plain filenames (no UUID -> no stripping).
    Falls back to "video" if everything is stripped.
    """
    stem = Path(input_path).stem

    # Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
    while _BLOCK_PREFIX_RE.match(stem):
        stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)

    # Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
    if _UUID_PREFIX_RE.match(stem):
        stem = _UUID_PREFIX_RE.sub("", stem, count=1)

    if not stem:
        return "video"

    return stem[:max_length]


def get_video_codecs(output_path: str) -> tuple[str, str]:
    """Get appropriate video and audio codecs based on output file extension.

    Args:
        output_path: Path to the output file (used to determine extension)

    Returns:
        Tuple of (video_codec, audio_codec)

    Codec mappings:
    - .mp4: H.264 + AAC (universal compatibility)
    - .webm: VP8 + Vorbis (web streaming)
    - .mkv: H.264 + AAC (container supports many codecs)
    - .mov: H.264 + AAC (Apple QuickTime, widely compatible)
    - .m4v: H.264 + AAC (Apple iTunes/devices)
    - .avi: MPEG-4 + MP3 (legacy Windows)
    """
    ext = os.path.splitext(output_path)[1].lower()

    codec_map: dict[str, tuple[str, str]] = {
        ".mp4": ("libx264", "aac"),
        ".webm": ("libvpx", "libvorbis"),
        ".mkv": ("libx264", "aac"),
        ".mov": ("libx264", "aac"),
        ".m4v": ("libx264", "aac"),
        ".avi": ("mpeg4", "libmp3lame"),
    }

    return codec_map.get(ext, ("libx264", "aac"))


def strip_chapters_inplace(video_path: str) -> None:
    """Strip chapter metadata from a media file in-place using ffmpeg.

    MoviePy 2.x crashes with IndexError when parsing files with embedded
    chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
    This strips chapters without re-encoding.

    Args:
        video_path: Absolute path to the media file to strip chapters from.
    """
    base, ext = os.path.splitext(video_path)
    tmp_path = base + ".tmp" + ext
    try:
        result = subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                video_path,
                "-map_chapters",
                "-1",
                "-codec",
                "copy",
                tmp_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
        if result.returncode != 0:
            logger.warning(
                "ffmpeg chapter strip failed (rc=%d): %s",
                result.returncode,
                result.stderr,
            )
            return
        os.replace(tmp_path, video_path)
    except FileNotFoundError:
        logger.warning("ffmpeg not found; skipping chapter strip")
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
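A hedged illustration of what the two pure helpers above return; the UUID is a made-up node_exec_id, and the import assumes the module shown above (which this commit removes):

from backend.blocks.video._utils import extract_source_name, get_video_codecs

# A chained-block filename collapses back to the original stem.
extract_source_name("123e4567-e89b-12d3-a456-426614174000_clip_holiday_trip.mp4")
# -> "holiday_trip"

# Codec selection follows the extension of the requested output file.
get_video_codecs("result.webm")  # -> ("libvpx", "libvorbis")
get_video_codecs("result.avi")   # -> ("mpeg4", "libmp3lame")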
@@ -1,113 +0,0 @@
|
|||||||
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class AddAudioToVideoBlock(Block):
|
|
||||||
"""Add (attach) an audio track to an existing video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Video input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
audio_in: MediaFileType = SchemaField(
|
|
||||||
description="Audio input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
volume: float = SchemaField(
|
|
||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
|
||||||
default=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Final video (with attached audio), as a path or data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
|
||||||
description="Block to attach an audio file to a video file using moviepy.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=AddAudioToVideoBlock.Input,
|
|
||||||
output_schema=AddAudioToVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the inputs locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
local_audio_path = await store_media_file(
|
|
||||||
file=input_data.audio_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
|
|
||||||
|
|
||||||
# 2) Load video + audio with moviepy
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
strip_chapters_inplace(audio_abspath)
|
|
||||||
video_clip = None
|
|
||||||
audio_clip = None
|
|
||||||
final_clip = None
|
|
||||||
try:
|
|
||||||
video_clip = VideoFileClip(video_abspath)
|
|
||||||
audio_clip = AudioFileClip(audio_abspath)
|
|
||||||
# Optionally scale volume
|
|
||||||
if input_data.volume != 1.0:
|
|
||||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
|
||||||
|
|
||||||
# 3) Attach the new audio track
|
|
||||||
final_clip = video_clip.with_audio(audio_clip)
|
|
||||||
|
|
||||||
# 4) Write to output file
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
final_clip.write_videofile(
|
|
||||||
output_abspath, codec="libx264", audio_codec="aac"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if final_clip:
|
|
||||||
final_clip.close()
|
|
||||||
if audio_clip:
|
|
||||||
audio_clip.close()
|
|
||||||
if video_clip:
|
|
||||||
video_clip.close()
|
|
||||||
|
|
||||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
"""VideoClipBlock - Extract a segment from a video file."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoClipBlock(Block):
|
|
||||||
"""Extract a time segment from a video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
|
||||||
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
|
||||||
description="Output format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Clipped video file (path or data URI)"
|
|
||||||
)
|
|
||||||
duration: float = SchemaField(description="Clip duration in seconds")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
|
||||||
description="Extract a time segment from a video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"video_in": "/tmp/test.mp4",
|
|
||||||
"start_time": 0.0,
|
|
||||||
"end_time": 10.0,
|
|
||||||
},
|
|
||||||
test_output=[("video_out", str), ("duration", float)],
|
|
||||||
test_mock={
|
|
||||||
"_clip_video": lambda *args: 10.0,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _clip_video(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
) -> float:
|
|
||||||
"""Extract a clip from a video. Extracted for testability."""
|
|
||||||
clip = None
|
|
||||||
subclip = None
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
clip = VideoFileClip(video_abspath)
|
|
||||||
subclip = clip.subclipped(start_time, end_time)
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
subclip.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
return subclip.duration
|
|
||||||
finally:
|
|
||||||
if subclip:
|
|
||||||
subclip.close()
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate time range
|
|
||||||
if input_data.end_time <= input_data.start_time:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = self._clip_video(
|
|
||||||
video_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.start_time,
|
|
||||||
input_data.end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "duration", duration
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to clip video: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy import concatenate_videoclips
|
|
||||||
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoConcatBlock(Block):
|
|
||||||
"""Merge multiple video clips into one continuous video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
videos: list[MediaFileType] = SchemaField(
|
|
||||||
description="List of video files to concatenate (in order)"
|
|
||||||
)
|
|
||||||
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
|
||||||
description="Transition between clips", default="none"
|
|
||||||
)
|
|
||||||
transition_duration: int = SchemaField(
|
|
||||||
description="Transition duration in seconds",
|
|
||||||
default=1,
|
|
||||||
ge=0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
|
||||||
description="Output format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Concatenated video file (path or data URI)"
|
|
||||||
)
|
|
||||||
total_duration: float = SchemaField(description="Total duration in seconds")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
|
||||||
description="Merge multiple video clips into one continuous video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
("video_out", str),
|
|
||||||
("total_duration", float),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_concat_videos": lambda *args: 20.0,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _concat_videos(
|
|
||||||
self,
|
|
||||||
video_abspaths: list[str],
|
|
||||||
output_abspath: str,
|
|
||||||
transition: str,
|
|
||||||
transition_duration: int,
|
|
||||||
) -> float:
|
|
||||||
"""Concatenate videos. Extracted for testability.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total duration of the concatenated video.
|
|
||||||
"""
|
|
||||||
clips = []
|
|
||||||
faded_clips = []
|
|
||||||
final = None
|
|
||||||
try:
|
|
||||||
# Load clips
|
|
||||||
for v in video_abspaths:
|
|
||||||
strip_chapters_inplace(v)
|
|
||||||
clips.append(VideoFileClip(v))
|
|
||||||
|
|
||||||
# Validate transition_duration against shortest clip
|
|
||||||
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
|
|
||||||
min_duration = min(c.duration for c in clips)
|
|
||||||
if transition_duration >= min_duration:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=(
|
|
||||||
f"transition_duration ({transition_duration}s) must be "
|
|
||||||
f"shorter than the shortest clip ({min_duration:.2f}s)"
|
|
||||||
),
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
if transition == "crossfade":
|
|
||||||
for i, clip in enumerate(clips):
|
|
||||||
effects = []
|
|
||||||
if i > 0:
|
|
||||||
effects.append(CrossFadeIn(transition_duration))
|
|
||||||
if i < len(clips) - 1:
|
|
||||||
effects.append(CrossFadeOut(transition_duration))
|
|
||||||
if effects:
|
|
||||||
clip = clip.with_effects(effects)
|
|
||||||
faded_clips.append(clip)
|
|
||||||
final = concatenate_videoclips(
|
|
||||||
faded_clips,
|
|
||||||
method="compose",
|
|
||||||
padding=-transition_duration,
|
|
||||||
)
|
|
||||||
elif transition == "fade_black":
|
|
||||||
for clip in clips:
|
|
||||||
faded = clip.with_effects(
|
|
||||||
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
|
||||||
)
|
|
||||||
faded_clips.append(faded)
|
|
||||||
final = concatenate_videoclips(faded_clips)
|
|
||||||
else:
|
|
||||||
final = concatenate_videoclips(clips)
|
|
||||||
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
return final.duration
|
|
||||||
finally:
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
for clip in faded_clips:
|
|
||||||
clip.close()
|
|
||||||
for clip in clips:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# Validate minimum clips
|
|
||||||
if len(input_data.videos) < 2:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message="At least 2 videos are required for concatenation",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store all input videos locally
|
|
||||||
video_abspaths = []
|
|
||||||
for video in input_data.videos:
|
|
||||||
local_path = await self._store_input_video(execution_context, video)
|
|
||||||
video_abspaths.append(
|
|
||||||
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build output path
|
|
||||||
source = (
|
|
||||||
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
|
|
||||||
)
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
total_duration = self._concat_videos(
|
|
||||||
video_abspaths,
|
|
||||||
output_abspath,
|
|
||||||
input_data.transition,
|
|
||||||
input_data.transition_duration,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "total_duration", total_duration
|
|
||||||
|
|
||||||
except BlockExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to concatenate videos: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,172 +0,0 @@
|
|||||||
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import typing
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import yt_dlp
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from yt_dlp import _Params
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoDownloadBlock(Block):
|
|
||||||
"""Download video from URL using yt-dlp."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
url: str = SchemaField(
|
|
||||||
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
|
||||||
placeholder="https://www.youtube.com/watch?v=...",
|
|
||||||
)
|
|
||||||
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
|
||||||
description="Video quality preference", default="720p"
|
|
||||||
)
|
|
||||||
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
|
||||||
description="Output video format", default="mp4", advanced=True
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_file: MediaFileType = SchemaField(
|
|
||||||
description="Downloaded video (path or data URI)"
|
|
||||||
)
|
|
||||||
duration: float = SchemaField(description="Video duration in seconds")
|
|
||||||
title: str = SchemaField(description="Video title from source")
|
|
||||||
source_url: str = SchemaField(description="Original source URL")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
|
||||||
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
|
|
||||||
test_input={
|
|
||||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
|
||||||
"quality": "480p",
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
("video_file", str),
|
|
||||||
("duration", float),
|
|
||||||
("title", str),
|
|
||||||
("source_url", str),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_download_video": lambda *args: (
|
|
||||||
"video.mp4",
|
|
||||||
212.0,
|
|
||||||
"Test Video",
|
|
||||||
),
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_format_string(self, quality: str) -> str:
|
|
||||||
formats = {
|
|
||||||
"best": "bestvideo+bestaudio/best",
|
|
||||||
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
|
||||||
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
|
||||||
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
|
||||||
"audio_only": "bestaudio/best",
|
|
||||||
}
|
|
||||||
return formats.get(quality, formats["720p"])
|
|
||||||
|
|
||||||
def _download_video(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
quality: str,
|
|
||||||
output_format: str,
|
|
||||||
output_dir: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
) -> tuple[str, float, str]:
|
|
||||||
"""Download video. Extracted for testability."""
|
|
||||||
output_template = os.path.join(
|
|
||||||
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
ydl_opts: "_Params" = {
|
|
||||||
"format": f"{self._get_format_string(quality)}/best",
|
|
||||||
"outtmpl": output_template,
|
|
||||||
"merge_output_format": output_format,
|
|
||||||
"quiet": True,
|
|
||||||
"no_warnings": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
|
||||||
info = ydl.extract_info(url, download=True)
|
|
||||||
video_path = ydl.prepare_filename(info)
|
|
||||||
|
|
||||||
# Handle format conversion in filename
|
|
||||||
if not video_path.endswith(f".{output_format}"):
|
|
||||||
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
|
||||||
|
|
||||||
# Return just the filename, not the full path
|
|
||||||
filename = os.path.basename(video_path)
|
|
||||||
|
|
||||||
return (
|
|
||||||
filename,
|
|
||||||
info.get("duration") or 0.0,
|
|
||||||
info.get("title") or "Unknown",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Get the exec file directory
|
|
||||||
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
filename, duration, title = self._download_video(
|
|
||||||
input_data.url,
|
|
||||||
input_data.quality,
|
|
||||||
input_data.output_format,
|
|
||||||
output_dir,
|
|
||||||
node_exec_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, MediaFileType(filename)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_file", video_out
|
|
||||||
yield "duration", duration
|
|
||||||
yield "title", title
|
|
||||||
yield "source_url", input_data.url
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to download video: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""MediaDurationBlock - Get the duration of a media file."""
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class MediaDurationBlock(Block):
|
|
||||||
"""Get the duration of a media file (video or audio)."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
media_in: MediaFileType = SchemaField(
|
|
||||||
description="Media input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
is_video: bool = SchemaField(
|
|
||||||
description="Whether the media is a video (True) or audio (False).",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
duration: float = SchemaField(
|
|
||||||
description="Duration of the media file (in seconds)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
|
||||||
description="Block to get the duration of a media file.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=MediaDurationBlock.Input,
|
|
||||||
output_schema=MediaDurationBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the input media locally
|
|
||||||
local_media_path = await store_media_file(
|
|
||||||
file=input_data.media_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
media_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_media_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Strip chapters to avoid MoviePy crash, then load the clip
|
|
||||||
strip_chapters_inplace(media_abspath)
|
|
||||||
clip = None
|
|
||||||
try:
|
|
||||||
if input_data.is_video:
|
|
||||||
clip = VideoFileClip(media_abspath)
|
|
||||||
else:
|
|
||||||
clip = AudioFileClip(media_abspath)
|
|
||||||
|
|
||||||
duration = clip.duration
|
|
||||||
finally:
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
yield "duration", duration
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from moviepy.video.fx.Loop import Loop
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class LoopVideoBlock(Block):
|
|
||||||
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="The input video (can be a URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
duration: Optional[float] = SchemaField(
|
|
||||||
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
|
|
||||||
default=None,
|
|
||||||
ge=0.0,
|
|
||||||
le=3600.0, # Max 1 hour to prevent disk exhaustion
|
|
||||||
)
|
|
||||||
n_loops: Optional[int] = SchemaField(
|
|
||||||
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
|
|
||||||
default=None,
|
|
||||||
ge=1,
|
|
||||||
le=10, # Max 10 loops to prevent disk exhaustion
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Looped video returned either as a relative path or a data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
|
||||||
description="Block to loop a video to a given duration or number of repeats.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=LoopVideoBlock.Input,
|
|
||||||
output_schema=LoopVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the input video locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
strip_chapters_inplace(input_abspath)
|
|
||||||
clip = None
|
|
||||||
looped_clip = None
|
|
||||||
try:
|
|
||||||
clip = VideoFileClip(input_abspath)
|
|
||||||
|
|
||||||
# 3) Apply the loop effect
|
|
||||||
if input_data.duration:
|
|
||||||
# Loop until we reach the specified duration
|
|
||||||
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
|
|
||||||
elif input_data.n_loops:
|
|
||||||
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
|
|
||||||
else:
|
|
||||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
|
||||||
|
|
||||||
assert isinstance(looped_clip, VideoFileClip)
|
|
||||||
|
|
||||||
# 4) Save the looped output
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
|
|
||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
|
||||||
looped_clip.write_videofile(
|
|
||||||
output_abspath, codec="libx264", audio_codec="aac"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if looped_clip:
|
|
||||||
looped_clip.close()
|
|
||||||
if clip:
|
|
||||||
clip.close()
|
|
||||||
|
|
||||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -1,267 +0,0 @@
|
|||||||
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from elevenlabs import ElevenLabs
|
|
||||||
from moviepy import CompositeAudioClip
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.blocks.elevenlabs._auth import (
|
|
||||||
TEST_CREDENTIALS,
|
|
||||||
TEST_CREDENTIALS_INPUT,
|
|
||||||
ElevenLabsCredentials,
|
|
||||||
ElevenLabsCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.blocks.video._utils import (
|
|
||||||
extract_source_name,
|
|
||||||
get_video_codecs,
|
|
||||||
strip_chapters_inplace,
|
|
||||||
)
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
|
||||||
from backend.util.exceptions import BlockExecutionError
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class VideoNarrationBlock(Block):
|
|
||||||
"""Generate AI narration and add to video."""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
|
||||||
description="ElevenLabs API key for voice synthesis"
|
|
||||||
)
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Input video (URL, data URI, or local path)"
|
|
||||||
)
|
|
||||||
script: str = SchemaField(description="Narration script text")
|
|
||||||
voice_id: str = SchemaField(
|
|
||||||
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
|
||||||
)
|
|
||||||
model_id: Literal[
|
|
||||||
"eleven_multilingual_v2",
|
|
||||||
"eleven_flash_v2_5",
|
|
||||||
"eleven_turbo_v2_5",
|
|
||||||
"eleven_turbo_v2",
|
|
||||||
] = SchemaField(
|
|
||||||
description="ElevenLabs TTS model",
|
|
||||||
default="eleven_multilingual_v2",
|
|
||||||
)
|
|
||||||
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
|
||||||
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
|
||||||
default="ducking",
|
|
||||||
)
|
|
||||||
narration_volume: float = SchemaField(
|
|
||||||
description="Narration volume (0.0 to 2.0)",
|
|
||||||
default=1.0,
|
|
||||||
ge=0.0,
|
|
||||||
le=2.0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
original_volume: float = SchemaField(
|
|
||||||
description="Original audio volume when mixing (0.0 to 1.0)",
|
|
||||||
default=0.3,
|
|
||||||
ge=0.0,
|
|
||||||
le=1.0,
|
|
||||||
advanced=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Video with narration (path or data URI)"
|
|
||||||
)
|
|
||||||
audio_file: MediaFileType = SchemaField(
|
|
||||||
description="Generated audio file (path or data URI)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
|
||||||
description="Generate AI narration and add to video",
|
|
||||||
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
|
||||||
input_schema=self.Input,
|
|
||||||
output_schema=self.Output,
|
|
||||||
test_input={
|
|
||||||
"video_in": "/tmp/test.mp4",
|
|
||||||
"script": "Hello world",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[("video_out", str), ("audio_file", str)],
|
|
||||||
test_mock={
|
|
||||||
"_generate_narration_audio": lambda *args: b"mock audio content",
|
|
||||||
"_add_narration_to_video": lambda *args: None,
|
|
||||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
|
||||||
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_input_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store input video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _store_output_video(
|
|
||||||
self, execution_context: ExecutionContext, file: MediaFileType
|
|
||||||
) -> MediaFileType:
|
|
||||||
"""Store output video. Extracted for testability."""
|
|
||||||
return await store_media_file(
|
|
||||||
file=file,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_narration_audio(
|
|
||||||
self, api_key: str, script: str, voice_id: str, model_id: str
|
|
||||||
) -> bytes:
|
|
||||||
"""Generate narration audio via ElevenLabs API."""
|
|
||||||
client = ElevenLabs(api_key=api_key)
|
|
||||||
audio_generator = client.text_to_speech.convert(
|
|
||||||
voice_id=voice_id,
|
|
||||||
text=script,
|
|
||||||
model_id=model_id,
|
|
||||||
)
|
|
||||||
# The SDK returns a generator, collect all chunks
|
|
||||||
return b"".join(audio_generator)
|
|
||||||
|
|
||||||
def _add_narration_to_video(
|
|
||||||
self,
|
|
||||||
video_abspath: str,
|
|
||||||
audio_abspath: str,
|
|
||||||
output_abspath: str,
|
|
||||||
mix_mode: str,
|
|
||||||
narration_volume: float,
|
|
||||||
original_volume: float,
|
|
||||||
) -> None:
|
|
||||||
"""Add narration audio to video. Extracted for testability."""
|
|
||||||
video = None
|
|
||||||
final = None
|
|
||||||
narration_original = None
|
|
||||||
narration_scaled = None
|
|
||||||
original = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
strip_chapters_inplace(video_abspath)
|
|
||||||
video = VideoFileClip(video_abspath)
|
|
||||||
narration_original = AudioFileClip(audio_abspath)
|
|
||||||
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
|
||||||
narration = narration_scaled
|
|
||||||
|
|
||||||
if mix_mode == "replace":
|
|
||||||
final_audio = narration
|
|
||||||
elif mix_mode == "mix":
|
|
||||||
if video.audio:
|
|
||||||
original = video.audio.with_volume_scaled(original_volume)
|
|
||||||
final_audio = CompositeAudioClip([original, narration])
|
|
||||||
else:
|
|
||||||
final_audio = narration
|
|
||||||
else: # ducking - apply stronger attenuation
|
|
||||||
if video.audio:
|
|
||||||
# Ducking uses a much lower volume for original audio
|
|
||||||
ducking_volume = original_volume * 0.3
|
|
||||||
original = video.audio.with_volume_scaled(ducking_volume)
|
|
||||||
final_audio = CompositeAudioClip([original, narration])
|
|
||||||
else:
|
|
||||||
final_audio = narration
|
|
||||||
|
|
||||||
final = video.with_audio(final_audio)
|
|
||||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
|
||||||
final.write_videofile(
|
|
||||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if original:
|
|
||||||
original.close()
|
|
||||||
if narration_scaled:
|
|
||||||
narration_scaled.close()
|
|
||||||
if narration_original:
|
|
||||||
narration_original.close()
|
|
||||||
if final:
|
|
||||||
final.close()
|
|
||||||
if video:
|
|
||||||
video.close()
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: ElevenLabsCredentials,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
node_exec_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
|
|
||||||
# Store the input video locally
|
|
||||||
local_video_path = await self._store_input_video(
|
|
||||||
execution_context, input_data.video_in
|
|
||||||
)
|
|
||||||
video_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_video_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate narration audio via ElevenLabs
|
|
||||||
audio_content = self._generate_narration_audio(
|
|
||||||
credentials.api_key.get_secret_value(),
|
|
||||||
input_data.script,
|
|
||||||
input_data.voice_id,
|
|
||||||
input_data.model_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save audio to exec file path
|
|
||||||
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
|
||||||
audio_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, audio_filename
|
|
||||||
)
|
|
||||||
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
|
||||||
with open(audio_abspath, "wb") as f:
|
|
||||||
f.write(audio_content)
|
|
||||||
|
|
||||||
# Add narration to video
|
|
||||||
source = extract_source_name(local_video_path)
|
|
||||||
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
|
|
||||||
output_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, output_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
self._add_narration_to_video(
|
|
||||||
video_abspath,
|
|
||||||
audio_abspath,
|
|
||||||
output_abspath,
|
|
||||||
input_data.mix_mode,
|
|
||||||
input_data.narration_volume,
|
|
||||||
input_data.original_volume,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return as workspace path or data URI based on context
|
|
||||||
video_out = await self._store_output_video(
|
|
||||||
execution_context, output_filename
|
|
||||||
)
|
|
||||||
audio_out = await self._store_output_video(
|
|
||||||
execution_context, audio_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
yield "audio_file", audio_out
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Failed to add narration: {e}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=str(self.id),
|
|
||||||
) from e
|
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
"""VideoTextOverlayBlock - Add text overlay to video."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from moviepy import CompositeVideoClip, TextClip
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoTextOverlayBlock(Block):
    """Add text overlay/caption to video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        text: str = SchemaField(description="Text to overlay on video")
        position: Literal[
            "top",
            "center",
            "bottom",
            "top-left",
            "top-right",
            "bottom-left",
            "bottom-right",
        ] = SchemaField(description="Position of text on screen", default="bottom")
        start_time: float | None = SchemaField(
            description="When to show text (seconds). None = entire video",
            default=None,
            advanced=True,
        )
        end_time: float | None = SchemaField(
            description="When to hide text (seconds). None = until end",
            default=None,
            advanced=True,
        )
        font_size: int = SchemaField(
            description="Font size", default=48, ge=12, le=200, advanced=True
        )
        font_color: str = SchemaField(
            description="Font color (hex or name)", default="white", advanced=True
        )
        bg_color: str | None = SchemaField(
            description="Background color behind text (None for transparent)",
            default=None,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Video with text overlay (path or data URI)"
        )

    def __init__(self):
        super().__init__(
            id="8ef14de6-cc90-430a-8cfa-3a003be92454",
            description="Add text overlay/caption to video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            disabled=True,  # Disable until we can lockdown imagemagick security policy
            test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
            test_output=[("video_out", str)],
            test_mock={
                "_add_text_overlay": lambda *args: None,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _add_text_overlay(
        self,
        video_abspath: str,
        output_abspath: str,
        text: str,
        position: str,
        start_time: float | None,
        end_time: float | None,
        font_size: int,
        font_color: str,
        bg_color: str | None,
    ) -> None:
        """Add text overlay to video. Extracted for testability."""
        video = None
        final = None
        txt_clip = None
        try:
            strip_chapters_inplace(video_abspath)
            video = VideoFileClip(video_abspath)

            txt_clip = TextClip(
                text=text,
                font_size=font_size,
                color=font_color,
                bg_color=bg_color,
            )

            # Position mapping
            pos_map = {
                "top": ("center", "top"),
                "center": ("center", "center"),
                "bottom": ("center", "bottom"),
                "top-left": ("left", "top"),
                "top-right": ("right", "top"),
                "bottom-left": ("left", "bottom"),
                "bottom-right": ("right", "bottom"),
            }

            txt_clip = txt_clip.with_position(pos_map[position])

            # Set timing
            start = start_time or 0
            end = end_time or video.duration
            duration = max(0, end - start)
            txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)

            final = CompositeVideoClip([video, txt_clip])
            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

        finally:
            if txt_clip:
                txt_clip.close()
            if final:
                final.close()
            if video:
                video.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate time range if both are provided
        if (
            input_data.start_time is not None
            and input_data.end_time is not None
            and input_data.end_time <= input_data.start_time
        ):
            raise BlockExecutionError(
                message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Build output path
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            self._add_text_overlay(
                video_abspath,
                output_abspath,
                input_data.text,
                input_data.position,
                input_data.start_time,
                input_data.end_time,
                input_data.font_size,
                input_data.font_color,
                input_data.bg_color,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to add text overlay: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
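For orientation, a minimal standalone sketch of the same overlay flow outside the Block wrapper. It assumes a moviepy 2.x build whose TextClip can fall back to a default font, and the file names are hypothetical; the real block resolves paths through store_media_file and get_exec_file_path instead.

from moviepy import CompositeVideoClip, TextClip, VideoFileClip

# Hypothetical local files; the block derives real paths from the execution context.
video = VideoFileClip("input.mp4")
caption = TextClip(text="Hello World", font_size=48, color="white")

# Same (horizontal, vertical) anchor convention as the block's pos_map.
caption = caption.with_position(("center", "bottom"))

# Show the caption from t=1s until the end of the clip.
start, end = 1.0, video.duration
caption = caption.with_start(start).with_end(end).with_duration(max(0, end - start))

CompositeVideoClip([video, caption]).write_videofile(
    "output.mp4", codec="libx264", audio_codec="aac"
)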
@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
         credentials: WebshareProxyCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            video_id = self.extract_video_id(input_data.youtube_url)
-            transcript = self.get_transcript(video_id, credentials)
-            transcript_text = self.format_transcript(transcript=transcript)
-
-            # Only yield after all operations succeed
-            yield "video_id", video_id
-            yield "transcript", transcript_text
-        except Exception as e:
-            yield "error", str(e)
+        video_id = self.extract_video_id(input_data.youtube_url)
+        yield "video_id", video_id
+
+        transcript = self.get_transcript(video_id, credentials)
+        transcript_text = self.format_transcript(transcript=transcript)
+
+        yield "transcript", transcript_text
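The only behavioral difference in this hunk is when outputs are emitted relative to the work that can fail. A small self-contained sketch (the fetch function is a stand-in, not the real get_transcript) shows why yielding between steps can hand a consumer a video_id even though the transcript step later raises:

def fetch_transcript(ok: bool) -> str:
    # Stand-in for the real transcript lookup; raises to simulate a failure.
    if not ok:
        raise RuntimeError("transcript unavailable")
    return "transcript text"

def yields_between_steps(ok: bool):
    yield "video_id", "abc123"               # already delivered to the consumer...
    yield "transcript", fetch_transcript(ok)  # ...even if this call raises

def yields_after_all_steps(ok: bool):
    transcript = fetch_transcript(ok)  # fail before emitting anything
    yield "video_id", "abc123"
    yield "transcript", transcript

assert list(yields_after_all_steps(True)) == list(yields_between_steps(True))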
@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
                        f"is not of type {CredentialsMetaInput.__name__}"
                    )

-                CredentialsMetaInput.validate_credentials_field_schema(
-                    cls.get_field_schema(field_name), field_name
-                )
+                credentials_fields[field_name].validate_credentials_field_schema(cls)

            elif field_name in credentials_fields:
                raise KeyError(
@@ -875,13 +873,14 @@ def is_block_auth_configured(


 async def initialize_blocks() -> None:
+    # First, sync all provider costs to blocks
+    # Imported here to avoid circular import
     from backend.sdk.cost_integration import sync_all_provider_costs
-    from backend.util.retry import func_retry

     sync_all_provider_costs()

-    @func_retry
-    async def sync_block_to_db(block: Block) -> None:
+    for cls in get_blocks().values():
+        block = cls()
         existing_block = await AgentBlock.prisma().find_first(
             where={"OR": [{"id": block.id}, {"name": block.name}]}
         )
@@ -894,7 +893,7 @@ async def initialize_blocks() -> None:
                     outputSchema=json.dumps(block.output_schema.jsonschema()),
                 )
             )
-            return
+            continue

         input_schema = json.dumps(block.input_schema.jsonschema())
         output_schema = json.dumps(block.output_schema.jsonschema())
@@ -914,25 +913,6 @@ async def initialize_blocks() -> None:
            },
        )

-    failed_blocks: list[str] = []
-    for cls in get_blocks().values():
-        block = cls()
-        try:
-            await sync_block_to_db(block)
-        except Exception as e:
-            logger.warning(
-                f"Failed to sync block {block.name} to database: {e}. "
-                "Block is still available in memory.",
-                exc_info=True,
-            )
-            failed_blocks.append(block.name)
-
-    if failed_blocks:
-        logger.error(
-            f"Failed to sync {len(failed_blocks)} block(s) to database: "
-            f"{', '.join(failed_blocks)}. These blocks are still available in memory."
-        )
-

 # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
 def get_block(block_id: str) -> AnyBlockSchema | None:
@@ -36,14 +36,12 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
 from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
 from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
 from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
-from backend.blocks.video.narration import VideoNarrationBlock
 from backend.data.block import Block, BlockCost, BlockCostType
 from backend.integrations.credentials_store import (
     aiml_api_credentials,
     anthropic_credentials,
     apollo_credentials,
     did_credentials,
-    elevenlabs_credentials,
     enrichlayer_credentials,
     groq_credentials,
     ideogram_credentials,
@@ -80,7 +78,6 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_1_OPUS: 21,
     LlmModel.CLAUDE_4_OPUS: 21,
     LlmModel.CLAUDE_4_SONNET: 5,
-    LlmModel.CLAUDE_4_6_OPUS: 14,
     LlmModel.CLAUDE_4_5_HAIKU: 4,
     LlmModel.CLAUDE_4_5_OPUS: 14,
     LlmModel.CLAUDE_4_5_SONNET: 9,
@@ -642,16 +639,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
            },
        ),
    ],
-    VideoNarrationBlock: [
-        BlockCost(
-            cost_amount=5,  # ElevenLabs TTS cost
-            cost_filter={
-                "credentials": {
-                    "id": elevenlabs_credentials.id,
-                    "provider": elevenlabs_credentials.provider,
-                    "type": elevenlabs_credentials.type,
-                }
-            },
-        )
-    ],
 }
@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
     month1 = datetime.now(timezone.utc).replace(month=1, day=1)
     user_credit.time_now = lambda: month1

-    # IMPORTANT: Set updatedAt to December of previous year to ensure it's
-    # in a different month than month1 (January). This fixes a timing bug
-    # where if the test runs in early February, 35 days ago would be January,
-    # matching the mocked month1 and preventing the refill from triggering.
-    dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
-    await UserBalance.prisma().update(
-        where={"userId": DEFAULT_USER_ID},
-        data={"updatedAt": dec_previous_year},
-    )
-
     # First call in month 1 should trigger refill
     balance = await user_credit.get_credits(DEFAULT_USER_ID)
     assert balance == REFILL_VALUE  # Should get 1000 credits
@@ -3,7 +3,7 @@ import logging
 import uuid
 from collections import defaultdict
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast

 from prisma.enums import SubmissionStatus
 from prisma.models import (
@@ -20,7 +20,7 @@ from prisma.types import (
     AgentNodeLinkCreateInput,
     StoreListingVersionWhereInput,
 )
-from pydantic import BaseModel, BeforeValidator, Field
+from pydantic import BaseModel, BeforeValidator, Field, create_model
 from pydantic.fields import computed_field

 from backend.blocks.agent import AgentExecutorBlock
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
 from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
 from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
 from backend.data.model import (
+    CredentialsField,
     CredentialsFieldInfo,
     CredentialsMetaInput,
     is_credentials_field_name,
@@ -39,12 +40,12 @@ from backend.util import type as type_utils
 from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
 from backend.util.json import SafeJson
 from backend.util.models import Pagination
-from backend.util.request import parse_url

 from .block import (
     AnyBlockSchema,
     Block,
     BlockInput,
+    BlockSchema,
     BlockType,
     EmptySchema,
     get_block,
@@ -112,12 +113,10 @@ class Link(BaseDbModel):

 class Node(BaseDbModel):
     block_id: str
-    input_default: BlockInput = Field(  # dict[input_name, default_value]
-        default_factory=dict
-    )
-    metadata: dict[str, Any] = Field(default_factory=dict)
-    input_links: list[Link] = Field(default_factory=list)
-    output_links: list[Link] = Field(default_factory=list)
+    input_default: BlockInput = {}  # dict[input_name, default_value]
+    metadata: dict[str, Any] = {}
+    input_links: list[Link] = []
+    output_links: list[Link] = []

     @property
     def credentials_optional(self) -> bool:
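Both spellings in this hunk behave the same under Pydantic v2, which copies mutable defaults per instance; Field(default_factory=...) just makes that explicit. A minimal check, with made-up model and field names:

from pydantic import BaseModel, Field

class WithLiteralDefault(BaseModel):
    metadata: dict[str, int] = {}

class WithFactoryDefault(BaseModel):
    metadata: dict[str, int] = Field(default_factory=dict)

# Each instance gets its own dict either way.
a, b = WithLiteralDefault(), WithLiteralDefault()
a.metadata["x"] = 1
assert b.metadata == {}

c, d = WithFactoryDefault(), WithFactoryDefault()
c.metadata["y"] = 2
assert d.metadata == {}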
@@ -222,33 +221,18 @@ class NodeModel(Node):
         return result


-class GraphBaseMeta(BaseDbModel):
-    """
-    Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
-    """
-
+class BaseGraph(BaseDbModel):
     version: int = 1
     is_active: bool = True
     name: str
     description: str
     instructions: str | None = None
     recommended_schedule_cron: str | None = None
+    nodes: list[Node] = []
+    links: list[Link] = []
     forked_from_id: str | None = None
     forked_from_version: int | None = None

-
-class BaseGraph(GraphBaseMeta):
-    """
-    Graph with nodes, links, and computed I/O schema fields.
-
-    Used to represent sub-graphs within a `Graph`. Contains the full graph
-    structure including nodes and links, plus computed fields for schemas
-    and trigger info. Does NOT include user_id or created_at (see GraphModel).
-    """
-
-    nodes: list[Node] = Field(default_factory=list)
-    links: list[Link] = Field(default_factory=list)
-
     @computed_field
     @property
     def input_schema(self) -> dict[str, Any]:
@@ -377,79 +361,44 @@ class GraphTriggerInfo(BaseModel):


 class Graph(BaseGraph):
-    """Creatable graph model used in API create/update endpoints."""
-
-    sub_graphs: list[BaseGraph] = Field(default_factory=list)  # Flattened sub-graphs
-
-
-class GraphMeta(GraphBaseMeta):
-    """
-    Lightweight graph metadata model representing an existing graph from the database,
-    for use in listings and summaries.
-
-    Lacks `GraphModel`'s nodes, links, and expensive computed fields.
-    Use for list endpoints where full graph data is not needed and performance matters.
-    """
-
-    id: str  # type: ignore
-    version: int  # type: ignore
-    user_id: str
-    created_at: datetime
-
-    @classmethod
-    def from_db(cls, graph: "AgentGraph") -> Self:
-        return cls(
-            id=graph.id,
-            version=graph.version,
-            is_active=graph.isActive,
-            name=graph.name or "",
-            description=graph.description or "",
-            instructions=graph.instructions,
-            recommended_schedule_cron=graph.recommendedScheduleCron,
-            forked_from_id=graph.forkedFromId,
-            forked_from_version=graph.forkedFromVersion,
-            user_id=graph.userId,
-            created_at=graph.createdAt,
-        )
-
-
-class GraphModel(Graph, GraphMeta):
-    """
-    Full graph model representing an existing graph from the database.
-
-    This is the primary model for working with persisted graphs. Includes all
-    graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
-    Provides computed fields (input_schema, output_schema, etc.) used during
-    set-up (frontend) and execution (backend).
-
-    Inherits from:
-    - `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
-    - `GraphMeta`: provides user_id, created_at for database records
-    """
-
-    nodes: list[NodeModel] = Field(default_factory=list)  # type: ignore
-
-    @property
-    def starting_nodes(self) -> list[NodeModel]:
-        outbound_nodes = {link.sink_id for link in self.links}
-        input_nodes = {
-            node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
-        }
-        return [
-            node
-            for node in self.nodes
-            if node.id not in outbound_nodes or node.id in input_nodes
-        ]
-
-    @property
-    def webhook_input_node(self) -> NodeModel | None:  # type: ignore
-        return cast(NodeModel, super().webhook_input_node)
+    sub_graphs: list[BaseGraph] = []  # Flattened sub-graphs

     @computed_field
     @property
     def credentials_input_schema(self) -> dict[str, Any]:
-        graph_credentials_inputs = self.aggregate_credentials_inputs()
+        schema = self._credentials_input_schema.jsonschema()
+
+        # Determine which credential fields are required based on credentials_optional metadata
+        graph_credentials_inputs = self.aggregate_credentials_inputs()
+        required_fields = []
+
+        # Build a map of node_id -> node for quick lookup
+        all_nodes = {node.id: node for node in self.nodes}
+        for sub_graph in self.sub_graphs:
+            for node in sub_graph.nodes:
+                all_nodes[node.id] = node
+
+        for field_key, (
+            _field_info,
+            node_field_pairs,
+        ) in graph_credentials_inputs.items():
+            # A field is required if ANY node using it has credentials_optional=False
+            is_required = False
+            for node_id, _field_name in node_field_pairs:
+                node = all_nodes.get(node_id)
+                if node and not node.credentials_optional:
+                    is_required = True
+                    break
+
+            if is_required:
+                required_fields.append(field_key)
+
+        schema["required"] = required_fields
+        return schema
+
+    @property
+    def _credentials_input_schema(self) -> type[BlockSchema]:
+        graph_credentials_inputs = self.aggregate_credentials_inputs()
         logger.debug(
             f"Combined credentials input fields for graph #{self.id} ({self.name}): "
             f"{graph_credentials_inputs}"
@@ -457,15 +406,12 @@ class GraphModel(Graph, GraphMeta):

         # Warn if same-provider credentials inputs can't be combined (= bad UX)
         graph_cred_fields = list(graph_credentials_inputs.values())
-        for i, (field, keys, _) in enumerate(graph_cred_fields):
-            for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
+        for i, (field, keys) in enumerate(graph_cred_fields):
+            for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
                 if field.provider != other_field.provider:
                     continue
                 if ProviderName.HTTP in field.provider:
                     continue
-                # MCP credentials are intentionally split by server URL
-                if ProviderName.MCP in field.provider:
-                    continue

                 # If this happens, that means a block implementation probably needs
                 # to be updated.
@@ -477,90 +423,31 @@ class GraphModel(Graph, GraphMeta):
                         f"keys: {keys} <> {other_keys}."
                     )

-        # Build JSON schema directly to avoid expensive create_model + validation overhead
-        properties = {}
-        required_fields = []
-
-        for agg_field_key, (
-            field_info,
-            _,
-            is_required,
-        ) in graph_credentials_inputs.items():
-            providers = list(field_info.provider)
-            cred_types = list(field_info.supported_types)
-
-            field_schema: dict[str, Any] = {
-                "credentials_provider": providers,
-                "credentials_types": cred_types,
-                "type": "object",
-                "properties": {
-                    "id": {"title": "Id", "type": "string"},
-                    "title": {
-                        "anyOf": [{"type": "string"}, {"type": "null"}],
-                        "default": None,
-                        "title": "Title",
-                    },
-                    "provider": {
-                        "title": "Provider",
-                        "type": "string",
-                        **(
-                            {"enum": providers}
-                            if len(providers) > 1
-                            else {"const": providers[0]}
-                        ),
-                    },
-                    "type": {
-                        "title": "Type",
-                        "type": "string",
-                        **(
-                            {"enum": cred_types}
-                            if len(cred_types) > 1
-                            else {"const": cred_types[0]}
-                        ),
-                    },
-                },
-                "required": ["id", "provider", "type"],
-            }
-
-            # Add a descriptive display title when URL-based discriminator values
-            # are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
-            if (
-                field_info.discriminator
-                and not field_info.discriminator_mapping
-                and field_info.discriminator_values
-            ):
-                hostnames = sorted(
-                    parse_url(str(v)).netloc for v in field_info.discriminator_values
-                )
-                field_schema["display_name"] = ", ".join(hostnames)
-
-            # Add other (optional) field info items
-            field_schema.update(
-                field_info.model_dump(
-                    by_alias=True,
-                    exclude_defaults=True,
-                    exclude={"provider", "supported_types"},  # already included above
-                )
-            )
-
-            # Ensure field schema is well-formed
-            CredentialsMetaInput.validate_credentials_field_schema(
-                field_schema, agg_field_key
-            )
-
-            properties[agg_field_key] = field_schema
-            if is_required:
-                required_fields.append(agg_field_key)
-
-        return {
-            "type": "object",
-            "properties": properties,
-            "required": required_fields,
-        }
+        fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
+            agg_field_key: (
+                CredentialsMetaInput[
+                    Literal[tuple(field_info.provider)],  # type: ignore
+                    Literal[tuple(field_info.supported_types)],  # type: ignore
+                ],
+                CredentialsField(
+                    required_scopes=set(field_info.required_scopes or []),
+                    discriminator=field_info.discriminator,
+                    discriminator_mapping=field_info.discriminator_mapping,
+                    discriminator_values=field_info.discriminator_values,
+                ),
+            )
+            for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
+        }
+
+        return create_model(
+            self.name.replace(" ", "") + "CredentialsInputSchema",
+            __base__=BlockSchema,
+            **fields,  # type: ignore
+        )

     def aggregate_credentials_inputs(
         self,
-    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
+    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
         """
         Returns:
             dict[aggregated_field_key, tuple(
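To make the removed schema-building code above concrete, this is roughly the per-field dict it produces for a hypothetical single-provider API-key credential (provider name and values are invented for illustration; real values come from CredentialsFieldInfo):

# Illustrative only: shape of one aggregated credentials field schema.
example_field_schema = {
    "credentials_provider": ["github"],  # assumed provider
    "credentials_types": ["api_key"],    # assumed credential type
    "type": "object",
    "properties": {
        "id": {"title": "Id", "type": "string"},
        "title": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
            "title": "Title",
        },
        "provider": {"title": "Provider", "type": "string", "const": "github"},
        "type": {"title": "Type", "type": "string", "const": "api_key"},
    },
    "required": ["id", "provider", "type"],
}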
@@ -568,19 +455,13 @@ class GraphModel(Graph, GraphMeta):
                 (now includes discriminator_values from matching nodes)
                 set[(node_id, field_name)]: Node credentials fields that are
                     compatible with this aggregated field spec
-                bool: True if the field is required (any node has credentials_optional=False)
             )]
         """
         # First collect all credential field data with input defaults
-        # Track (field_info, (node_id, field_name), is_required) for each credential field
-        node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
-        node_required_map: dict[str, bool] = {}  # node_id -> is_required
+        node_credential_data = []

         for graph in [self] + self.sub_graphs:
             for node in graph.nodes:
-                # Track if this node requires credentials (credentials_optional=False means required)
-                node_required_map[node.id] = not node.credentials_optional
-
                 for (
                     field_name,
                     field_info,
@@ -604,21 +485,37 @@ class GraphModel(Graph, GraphMeta):
                    )

        # Combine credential field info (this will merge discriminator_values automatically)
-        combined = CredentialsFieldInfo.combine(*node_credential_data)
-
-        # Add is_required flag to each aggregated field
-        # A field is required if ANY node using it has credentials_optional=False
-        return {
-            key: (
-                field_info,
-                node_field_pairs,
-                any(
-                    node_required_map.get(node_id, True)
-                    for node_id, _ in node_field_pairs
-                ),
-            )
-            for key, (field_info, node_field_pairs) in combined.items()
-        }
+        return CredentialsFieldInfo.combine(*node_credential_data)
+
+
+class GraphModel(Graph):
+    user_id: str
+    nodes: list[NodeModel] = []  # type: ignore
+
+    created_at: datetime
+
+    @property
+    def starting_nodes(self) -> list[NodeModel]:
+        outbound_nodes = {link.sink_id for link in self.links}
+        input_nodes = {
+            node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
+        }
+        return [
+            node
+            for node in self.nodes
+            if node.id not in outbound_nodes or node.id in input_nodes
+        ]
+
+    @property
+    def webhook_input_node(self) -> NodeModel | None:  # type: ignore
+        return cast(NodeModel, super().webhook_input_node)
+
+    def meta(self) -> "GraphMeta":
+        """
+        Returns a GraphMeta object with metadata about the graph.
+        This is used to return metadata about the graph without exposing nodes and links.
+        """
+        return GraphMeta.from_graph(self)
+
     def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
        """
@@ -902,14 +799,13 @@ class GraphModel(Graph, GraphMeta):
                if is_static_output_block(link.source_id):
                    link.is_static = True  # Each value block output should be static.

-    @classmethod
-    def from_db(  # type: ignore[reportIncompatibleMethodOverride]
-        cls,
+    @staticmethod
+    def from_db(
        graph: AgentGraph,
        for_export: bool = False,
        sub_graphs: list[AgentGraph] | None = None,
-    ) -> Self:
-        return cls(
+    ) -> "GraphModel":
+        return GraphModel(
            id=graph.id,
            user_id=graph.userId if not for_export else "",
            version=graph.version,
@@ -935,28 +831,17 @@ class GraphModel(Graph, GraphMeta):
            ],
        )

-    def hide_nodes(self) -> "GraphModelWithoutNodes":
-        """
-        Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
-        (excluded from serialization). They are still present in the model instance
-        so all computed fields (e.g. `credentials_input_schema`) still work.
-        """
-        return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
-
-
-class GraphModelWithoutNodes(GraphModel):
-    """
-    GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
-
-    Used in contexts like the store where exposing internal graph structure
-    is not desired. Inherits all computed fields from GraphModel but marks
-    nodes and links as excluded from JSON output.
-    """
-
-    nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
-    links: list[Link] = Field(default_factory=list, exclude=True)
-
-    sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
+
+class GraphMeta(Graph):
+    user_id: str
+
+    # Easy work-around to prevent exposing nodes and links in the API response
+    nodes: list[NodeModel] = Field(default=[], exclude=True)  # type: ignore
+    links: list[Link] = Field(default=[], exclude=True)
+
+    @staticmethod
+    def from_graph(graph: GraphModel) -> "GraphMeta":
+        return GraphMeta(**graph.model_dump())


 class GraphsPaginated(BaseModel):
@@ -1027,11 +912,21 @@ async def list_graphs_paginated(
        where=where_clause,
        distinct=["id"],
        order={"version": "desc"},
+        include=AGENT_GRAPH_INCLUDE,
        skip=offset,
        take=page_size,
    )

-    graph_models = [GraphMeta.from_db(graph) for graph in graphs]
+    graph_models: list[GraphMeta] = []
+    for graph in graphs:
+        try:
+            graph_meta = GraphModel.from_db(graph).meta()
+            # Trigger serialization to validate that the graph is well formed
+            graph_meta.model_dump()
+            graph_models.append(graph_meta)
+        except Exception as e:
+            logger.error(f"Error processing graph {graph.id}: {e}")
+            continue

    return GraphsPaginated(
        graphs=graph_models,
@@ -463,108 +463,3 @@ def test_node_credentials_optional_with_other_metadata():
     assert node.credentials_optional is True
     assert node.metadata["position"] == {"x": 100, "y": 200}
     assert node.metadata["customized_name"] == "My Custom Node"
-
-
-# ============================================================================
-# Tests for MCP Credential Deduplication
-# ============================================================================
-
-
-def test_mcp_credential_combine_different_servers():
-    """Two MCP credential fields with different server URLs should produce
-    separate entries when combined (not merged into one)."""
-    from backend.data.model import CredentialsFieldInfo
-    from backend.integrations.providers import ProviderName
-
-    field_sentry = CredentialsFieldInfo(
-        credentials_provider=frozenset([ProviderName.MCP]),
-        credentials_types=frozenset(["oauth2"]),
-        discriminator="server_url",
-        discriminator_values={"https://mcp.sentry.dev/mcp"},
-    )
-    field_linear = CredentialsFieldInfo(
-        credentials_provider=frozenset([ProviderName.MCP]),
-        credentials_types=frozenset(["oauth2"]),
-        discriminator="server_url",
-        discriminator_values={"https://mcp.linear.app/mcp"},
-    )
-
-    combined = CredentialsFieldInfo.combine(
-        (field_sentry, ("node-sentry", "credentials")),
-        (field_linear, ("node-linear", "credentials")),
-    )
-
-    # Should produce 2 separate credential entries
-    assert len(combined) == 2, (
-        f"Expected 2 credential entries for 2 MCP blocks with different servers, "
-        f"got {len(combined)}: {list(combined.keys())}"
-    )
-
-    # Each entry should contain the server hostname in its key
-    keys = list(combined.keys())
-    assert any(
-        "mcp.sentry.dev" in k for k in keys
-    ), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
-    assert any(
-        "mcp.linear.app" in k for k in keys
-    ), f"Expected 'mcp.linear.app' in one key, got {keys}"
-
-
-def test_mcp_credential_combine_same_server():
-    """Two MCP credential fields with the same server URL should be combined
-    into one credential entry."""
-    from backend.data.model import CredentialsFieldInfo
-    from backend.integrations.providers import ProviderName
-
-    field_a = CredentialsFieldInfo(
-        credentials_provider=frozenset([ProviderName.MCP]),
-        credentials_types=frozenset(["oauth2"]),
-        discriminator="server_url",
-        discriminator_values={"https://mcp.sentry.dev/mcp"},
-    )
-    field_b = CredentialsFieldInfo(
-        credentials_provider=frozenset([ProviderName.MCP]),
-        credentials_types=frozenset(["oauth2"]),
-        discriminator="server_url",
-        discriminator_values={"https://mcp.sentry.dev/mcp"},
-    )
-
-    combined = CredentialsFieldInfo.combine(
-        (field_a, ("node-a", "credentials")),
-        (field_b, ("node-b", "credentials")),
-    )
-
-    # Should produce 1 credential entry (same server URL)
-    assert len(combined) == 1, (
-        f"Expected 1 credential entry for 2 MCP blocks with same server, "
-        f"got {len(combined)}: {list(combined.keys())}"
-    )
-
-
-def test_mcp_credential_combine_no_discriminator_values():
-    """MCP credential fields without discriminator_values should be merged
-    into a single entry (backwards compat for blocks without server_url set)."""
-    from backend.data.model import CredentialsFieldInfo
-    from backend.integrations.providers import ProviderName
-
-    field_a = CredentialsFieldInfo(
-        credentials_provider=frozenset([ProviderName.MCP]),
-        credentials_types=frozenset(["oauth2"]),
-        discriminator="server_url",
-    )
-    field_b = CredentialsFieldInfo(
-        credentials_provider=frozenset([ProviderName.MCP]),
-        credentials_types=frozenset(["oauth2"]),
-        discriminator="server_url",
-    )
-
-    combined = CredentialsFieldInfo.combine(
-        (field_a, ("node-a", "credentials")),
-        (field_b, ("node-b", "credentials")),
-    )
-
-    # Should produce 1 entry (no URL differentiation)
-    assert len(combined) == 1, (
-        f"Expected 1 credential entry for MCP blocks without discriminator_values, "
-        f"got {len(combined)}: {list(combined.keys())}"
-    )
@@ -19,6 +19,7 @@ from typing import (
     cast,
     get_args,
 )
+from urllib.parse import urlparse
 from uuid import uuid4

 from prisma.enums import CreditTransactionType, OnboardingStep
@@ -41,7 +42,6 @@ from typing_extensions import TypedDict

 from backend.integrations.providers import ProviderName
 from backend.util.json import loads as json_loads
-from backend.util.request import parse_url
 from backend.util.settings import Secrets

 # Type alias for any provider name (including custom ones)
@@ -163,6 +163,7 @@ class User(BaseModel):
 if TYPE_CHECKING:
     from prisma.models import User as PrismaUser

+    from backend.data.block import BlockSchema

 T = TypeVar("T")
 logger = logging.getLogger(__name__)
@@ -396,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
     def matches_url(self, url: str) -> bool:
         """Check if this credential should be applied to the given URL."""

-        request_host, request_port = _extract_host_from_url(url)
-        cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
+        parsed_url = urlparse(url)
+        # Extract hostname without port
+        request_host = parsed_url.hostname
         if not request_host:
             return False

-        # If a port is specified in credential host, the request host port must match
-        if cred_scope_port is not None and request_port != cred_scope_port:
-            return False
-        # Non-standard ports are only allowed if explicitly specified in credential host
-        elif cred_scope_port is None and request_port not in (80, 443, None):
-            return False
-
-        # Simple host matching
-        if cred_scope_host == request_host:
+        # Simple host matching - exact match or wildcard subdomain match
+        if self.host == request_host:
             return True

         # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
-        if cred_scope_host.startswith("*."):
-            domain = cred_scope_host[2:]  # Remove "*."
+        if self.host.startswith("*."):
+            domain = self.host[2:]  # Remove "*."
             return request_host.endswith(f".{domain}") or request_host == domain

         return False
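A self-contained sketch of the simplified matching kept on the + side, using only urllib.parse and hypothetical hosts; it shows the exact-match and "*.domain" wildcard cases (ports are ignored in this variant):

from urllib.parse import urlparse

def host_matches(cred_host: str, url: str) -> bool:
    # Mirrors the simplified logic: exact hostname match or wildcard subdomain match.
    request_host = urlparse(url).hostname
    if not request_host:
        return False
    if cred_host == request_host:
        return True
    if cred_host.startswith("*."):
        domain = cred_host[2:]
        return request_host.endswith(f".{domain}") or request_host == domain
    return False

assert host_matches("api.example.com", "https://api.example.com/v1")
assert host_matches("*.example.com", "https://sub.api.example.com/test")
assert not host_matches("*.example.com", "https://example.org/test")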
@@ -507,13 +502,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
     def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
         return get_args(cls.model_fields["type"].annotation)

-    @staticmethod
-    def validate_credentials_field_schema(
-        field_schema: dict[str, Any], field_name: str
-    ):
+    @classmethod
+    def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
         """Validates the schema of a credentials input field"""
+        field_name = next(
+            name for name, type in model.get_credentials_fields().items() if type is cls
+        )
+        field_schema = model.jsonschema()["properties"][field_name]
         try:
-            field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
+            schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
         except ValidationError as e:
             if "Field required [type=missing" not in str(e):
                 raise
@@ -523,11 +520,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
                 f"{field_schema}"
             ) from e

-        providers = field_info.provider
+        providers = cls.allowed_providers()
         if (
             providers is not None
             and len(providers) > 1
-            and not field_info.discriminator
+            and not schema_extra.discriminator
         ):
             raise TypeError(
                 f"Multi-provider CredentialsField '{field_name}' "
@@ -554,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
     )


-def _extract_host_from_url(url: str) -> tuple[str, int | None]:
-    """Extract host and port from URL for grouping host-scoped credentials."""
+def _extract_host_from_url(url: str) -> str:
+    """Extract host from URL for grouping host-scoped credentials."""
     try:
-        parsed = parse_url(url)
-        return parsed.hostname or url, parsed.port
+        parsed = urlparse(url)
+        return parsed.hostname or url
     except Exception:
-        return "", None
+        return ""


 class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
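For comparison, what the two helper variants return for the same URL. The example assumes backend.util.request.parse_url behaves like urllib.parse.urlparse for hostname and port, which is only an assumption here; the URL itself is hypothetical.

from urllib.parse import urlparse

url = "http://localhost:8080/api/v1"  # hypothetical
parsed = urlparse(url)

# Tuple variant (old side): host and port stay separate for port-aware matching.
host_and_port = (parsed.hostname or url, parsed.port)  # ("localhost", 8080)

# String variant (new side): the port is dropped entirely.
host_only = parsed.hostname or url  # "localhost"

assert host_and_port == ("localhost", 8080) and host_only == "localhost"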
@@ -603,20 +600,13 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
        ] = defaultdict(list)

        for field, key in fields:
-            if (
-                field.discriminator
-                and not field.discriminator_mapping
-                and field.discriminator_values
-            ):
-                # URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
-                # Each unique host gets its own credential entry.
-                provider_prefix = next(iter(field.provider))
-                # Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
-                prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
+            if field.provider == frozenset([ProviderName.HTTP]):
+                # HTTP host-scoped credentials can have different hosts that reqires different credential sets.
+                # Group by host extracted from the URL
                providers = frozenset(
-                    [cast(CP, prefix_str)]
+                    [cast(CP, "http")]
                    + [
-                        cast(CP, parse_url(str(value)).netloc)
+                        cast(CP, _extract_host_from_url(str(value)))
                        for value in field.discriminator_values
                    ]
                )
@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
            headers={"Authorization": SecretStr("Bearer token")},
        )

-        # Non-standard ports require explicit port in credential host
-        assert not creds.matches_url("http://localhost:8080/api/v1")
+        assert creds.matches_url("http://localhost:8080/api/v1")
        assert creds.matches_url("https://localhost:443/secure/endpoint")
        assert creds.matches_url("http://localhost/simple")

-    def test_matches_url_with_explicit_port(self):
-        """Test URL matching with explicit port in credential host."""
-        creds = HostScopedCredentials(
-            provider="custom",
-            host="localhost:8080",
-            headers={"Authorization": SecretStr("Bearer token")},
-        )
-
-        assert creds.matches_url("http://localhost:8080/api/v1")
-        assert not creds.matches_url("http://localhost:3000/api/v1")
-        assert not creds.matches_url("http://localhost/simple")
-
    def test_empty_headers_dict(self):
        """Test HostScopedCredentials with empty headers."""
        creds = HostScopedCredentials(
@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
            ("*.example.com", "https://sub.api.example.com/test", True),
            ("*.example.com", "https://example.com/test", True),
            ("*.example.com", "https://example.org/test", False),
-            # Non-standard ports require explicit port in credential host
-            ("localhost", "http://localhost:3000/test", False),
-            ("localhost:3000", "http://localhost:3000/test", True),
+            ("localhost", "http://localhost:3000/test", True),
            ("localhost", "http://127.0.0.1:3000/test", False),
-            # IPv6 addresses (frontend stores with brackets via URL.hostname)
-            ("[::1]", "http://[::1]/test", True),
-            ("[::1]", "http://[::1]:80/test", True),
-            ("[::1]", "https://[::1]:443/test", True),
-            ("[::1]", "http://[::1]:8080/test", False),  # Non-standard port
-            ("[::1]:8080", "http://[::1]:8080/test", True),
-            ("[::1]:8080", "http://[::1]:9090/test", False),
-            ("[2001:db8::1]", "http://[2001:db8::1]/path", True),
-            ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
-            ("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
        ],
    )
    def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
@@ -18,7 +18,6 @@ from redis.asyncio.lock import Lock as AsyncRedisLock

 from backend.blocks.agent import AgentExecutorBlock
 from backend.blocks.io import AgentOutputBlock
-from backend.blocks.mcp.block import MCPToolBlock
 from backend.data import redis_client as redis
 from backend.data.block import (
     BlockInput,
@@ -230,10 +229,6 @@ async def execute_node(
        _input_data.nodes_input_masks = nodes_input_masks
        _input_data.user_id = user_id
        input_data = _input_data.model_dump()
-    elif isinstance(node_block, MCPToolBlock):
-        _mcp_data = MCPToolBlock.Input(**node.input_default)
-        _mcp_data.tool_arguments = input_data
-        input_data = _mcp_data.model_dump()
    data.inputs = input_data

    # Execute the node
@@ -270,12 +265,7 @@ async def execute_node(

    # Handle regular credentials fields
    for field_name, input_type in input_model.get_credentials_fields().items():
-        field_value = input_data.get(field_name)
-        if not field_value or (
-            isinstance(field_value, dict) and not field_value.get("id")
-        ):
-            continue  # No credentials configured — block runs without
-        credentials_meta = input_type(**field_value)
+        credentials_meta = input_type(**input_data[field_name])
        credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
        creds_locks.append(lock)
        extra_exec_kwargs[field_name] = credentials
@@ -339,16 +339,16 @@ async def _validate_node_input_credentials(
                ] = "Invalid credentials: type/provider mismatch"
                continue

-        # If node has optional credentials and any are missing, allow running without.
-        # The executor will pass credentials=None to the block's run().
+        # If node has optional credentials and any are missing, mark for skipping
+        # But only if there are no other errors for this node
        if (
            has_missing_credentials
            and node.credentials_optional
            and node.id not in credential_errors
        ):
+            nodes_to_skip.add(node.id)
            logger.info(
-                f"Node #{node.id}: optional credentials not configured, "
-                "running without"
+                f"Node #{node.id} will be skipped: optional credentials not configured"
            )

    return credential_errors, nodes_to_skip
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
    # Get aggregated credentials fields for the graph
    graph_cred_inputs = graph.aggregate_credentials_inputs()

-    for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
+    for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
        # Best-effort map: skip missing items
        if graph_input_name not in graph_credentials_input:
            continue
@@ -224,14 +224,6 @@ openweathermap_credentials = APIKeyCredentials(
     expires_at=None,
 )

-elevenlabs_credentials = APIKeyCredentials(
-    id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
-    provider="elevenlabs",
-    api_key=SecretStr(settings.secrets.elevenlabs_api_key),
-    title="Use Credits for ElevenLabs",
-    expires_at=None,
-)
-
 DEFAULT_CREDENTIALS = [
     ollama_credentials,
     revid_credentials,
@@ -260,7 +252,6 @@ DEFAULT_CREDENTIALS = [
     v0_credentials,
     webshare_proxy_credentials,
     openweathermap_credentials,
-    elevenlabs_credentials,
 ]

 SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
@@ -375,8 +366,6 @@ class IntegrationCredentialsStore:
            all_credentials.append(webshare_proxy_credentials)
        if settings.secrets.openweathermap_api_key:
            all_credentials.append(openweathermap_credentials)
-        if settings.secrets.elevenlabs_api_key:
-            all_credentials.append(elevenlabs_credentials)
        return all_credentials

    async def get_creds_by_id(
@@ -137,10 +137,7 @@ class IntegrationCredentialsManager:
        self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
    ) -> OAuth2Credentials:
        async with self._locked(user_id, credentials.id, "refresh"):
-            if credentials.provider == str(ProviderName.MCP):
-                oauth_handler = _create_mcp_oauth_handler(credentials)
-            else:
-                oauth_handler = await _get_provider_oauth_handler(credentials.provider)
+            oauth_handler = await _get_provider_oauth_handler(credentials.provider)
            if oauth_handler.needs_refresh(credentials):
                logger.debug(
                    f"Refreshing '{credentials.provider}' "
@@ -239,25 +236,3 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
|
|||||||
client_secret=client_secret,
|
client_secret=client_secret,
|
||||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_mcp_oauth_handler(
|
|
||||||
credentials: OAuth2Credentials,
|
|
||||||
) -> "BaseOAuthHandler":
|
|
||||||
"""Create an MCPOAuthHandler from credential metadata for token refresh.
|
|
||||||
|
|
||||||
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
|
|
||||||
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
|
|
||||||
is reconstructed from metadata stored on the credential during initial auth.
|
|
||||||
"""
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
|
|
||||||
meta = credentials.metadata or {}
|
|
||||||
return MCPOAuthHandler(
|
|
||||||
client_id=meta.get("mcp_client_id", ""),
|
|
||||||
client_secret=meta.get("mcp_client_secret", ""),
|
|
||||||
redirect_uri="", # Not needed for token refresh
|
|
||||||
authorize_url="", # Not needed for token refresh
|
|
||||||
token_url=meta.get("mcp_token_url", ""),
|
|
||||||
resource_url=meta.get("mcp_resource_url"),
|
|
||||||
)
|
|
||||||
|
|||||||
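
For context on the MCP refresh path removed above: the handler is rebuilt from metadata saved on the credential during the initial MCP auth. Below is a minimal sketch of that metadata and its use, with made-up values and a hypothetical "credentials" variable; only MCPOAuthHandler and needs_refresh appear in the diff itself.

# Sketch only - illustrative values, not real credentials.
example_metadata = {
    "mcp_client_id": "client-123",
    "mcp_client_secret": "secret-456",
    "mcp_token_url": "https://mcp.example.com/oauth/token",
    "mcp_resource_url": "https://mcp.example.com/",
}

# from backend.blocks.mcp.oauth import MCPOAuthHandler  # import path as in the hunk above
# handler = MCPOAuthHandler(
#     client_id=example_metadata["mcp_client_id"],
#     client_secret=example_metadata["mcp_client_secret"],
#     redirect_uri="", authorize_url="",  # unused for token refresh
#     token_url=example_metadata["mcp_token_url"],
#     resource_url=example_metadata["mcp_resource_url"],
# )
# if handler.needs_refresh(credentials):
#     ...  # refresh as the credentials manager does above
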
@@ -18,7 +18,6 @@ class ProviderName(str, Enum):
     DISCORD = "discord"
     D_ID = "d_id"
     E2B = "e2b"
-    ELEVENLABS = "elevenlabs"
     FAL = "fal"
     GITHUB = "github"
     GOOGLE = "google"
@@ -30,7 +29,6 @@ class ProviderName(str, Enum):
     IDEOGRAM = "ideogram"
     JINA = "jina"
     LLAMA_API = "llama_api"
-    MCP = "mcp"
     MEDIUM = "medium"
     MEM0 = "mem0"
     NOTION = "notion"
@@ -8,8 +8,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 from urllib.parse import urlparse
 
-from pydantic import BaseModel
-
 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.request import Requests
 from backend.util.settings import Config
@@ -19,35 +17,6 @@ from backend.util.virus_scanner import scan_content_safe
 if TYPE_CHECKING:
     from backend.data.execution import ExecutionContext
 
-
-class WorkspaceUri(BaseModel):
-    """Parsed workspace:// URI."""
-
-    file_ref: str  # File ID or path (e.g. "abc123" or "/path/to/file.txt")
-    mime_type: str | None = None  # MIME type from fragment (e.g. "video/mp4")
-    is_path: bool = False  # True if file_ref is a path (starts with "/")
-
-
-def parse_workspace_uri(uri: str) -> WorkspaceUri:
-    """Parse a workspace:// URI into its components.
-
-    Examples:
-        "workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
-        "workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
-        "workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
-    """
-    raw = uri.removeprefix("workspace://")
-    mime_type: str | None = None
-    if "#" in raw:
-        raw, fragment = raw.split("#", 1)
-        mime_type = fragment or None
-    return WorkspaceUri(
-        file_ref=raw,
-        mime_type=mime_type,
-        is_path=raw.startswith("/"),
-    )
-
-
 # Return format options for store_media_file
 # - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
 # - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
@@ -214,20 +183,22 @@ async def store_media_file(
                 "This file type is only available in CoPilot sessions."
             )
 
-        # Parse workspace reference (strips #mimeType fragment from file ID)
-        ws = parse_workspace_uri(file)
+        # Parse workspace reference
+        # workspace://abc123 - by file ID
+        # workspace:///path/to/file.txt - by virtual path
+        file_ref = file[12:]  # Remove "workspace://"
 
-        if ws.is_path:
-            # Path reference: workspace:///path/to/file.txt
-            workspace_content = await workspace_manager.read_file(ws.file_ref)
-            file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
+        if file_ref.startswith("/"):
+            # Path reference
+            workspace_content = await workspace_manager.read_file(file_ref)
+            file_info = await workspace_manager.get_file_info_by_path(file_ref)
             filename = sanitize_filename(
                 file_info.name if file_info else f"{uuid.uuid4()}.bin"
             )
         else:
-            # ID reference: workspace://abc123 or workspace://abc123#video/mp4
-            workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
-            file_info = await workspace_manager.get_file_info(ws.file_ref)
+            # ID reference
+            workspace_content = await workspace_manager.read_file_by_id(file_ref)
+            file_info = await workspace_manager.get_file_info(file_ref)
             filename = sanitize_filename(
                 file_info.name if file_info else f"{uuid.uuid4()}.bin"
             )
@@ -363,21 +334,7 @@ async def store_media_file(
 
         # Don't re-save if input was already from workspace
         if is_from_workspace:
-            # Return original workspace reference, ensuring MIME type fragment
-            ws = parse_workspace_uri(file)
-            if not ws.mime_type:
-                # Add MIME type fragment if missing (older refs without it)
-                try:
-                    if ws.is_path:
-                        info = await workspace_manager.get_file_info_by_path(
-                            ws.file_ref
-                        )
-                    else:
-                        info = await workspace_manager.get_file_info(ws.file_ref)
-                    if info:
-                        return MediaFileType(f"{file}#{info.mimeType}")
-                except Exception:
-                    pass
+            # Return original workspace reference
             return MediaFileType(file)
 
         # Save new content to workspace
@@ -389,7 +346,7 @@ async def store_media_file(
             filename=filename,
             overwrite=True,
         )
-        return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
+        return MediaFileType(f"workspace://{file_record.id}")
 
     else:
         raise ValueError(f"Invalid return_format: {return_format}")
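
To make the removed workspace:// handling concrete, here is a self-contained sketch of the URI parsing that WorkspaceUri / parse_workspace_uri performed (local stand-ins defined inline for illustration, not imports from the codebase):

from dataclasses import dataclass

@dataclass
class WorkspaceUriExample:
    file_ref: str
    mime_type: str | None = None
    is_path: bool = False

def parse_workspace_uri_example(uri: str) -> WorkspaceUriExample:
    # "workspace://abc123#video/mp4" -> file_ref="abc123", mime_type="video/mp4"
    raw = uri.removeprefix("workspace://")
    mime_type = None
    if "#" in raw:
        raw, fragment = raw.split("#", 1)
        mime_type = fragment or None
    return WorkspaceUriExample(file_ref=raw, mime_type=mime_type, is_path=raw.startswith("/"))

assert parse_workspace_uri_example("workspace://abc123").file_ref == "abc123"
assert parse_workspace_uri_example("workspace://abc123#video/mp4").mime_type == "video/mp4"
assert parse_workspace_uri_example("workspace:///path/to/file.txt").is_path is True
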
@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
     def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
         self.ssl_hostname = ssl_hostname
         self.ip_addresses = ip_addresses
-        self._default = aiohttp.ThreadedResolver()
+        self._default = aiohttp.AsyncResolver()
 
     async def resolve(self, host, port=0, family=socket.AF_INET):
         if host == self.ssl_hostname:
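
The one-line change above swaps the fallback resolver: aiohttp.AsyncResolver is backed by aiodns, while aiohttp.ThreadedResolver runs getaddrinfo in the default executor. A stand-alone sketch of a pinning resolver in the same spirit (the class and names here are local to this example, not the project's HostResolver):

import socket
import aiohttp

class PinnedResolverExample(aiohttp.abc.AbstractResolver):
    """Answer one hostname from a fixed IP list; delegate everything else."""

    def __init__(self, hostname: str, ips: list[str]):
        self.hostname, self.ips = hostname, ips
        self.default = aiohttp.ThreadedResolver()  # AsyncResolver() needs aiodns installed

    async def resolve(self, host, port=0, family=socket.AF_INET):
        if host == self.hostname:
            return [
                {"hostname": host, "host": ip, "port": port,
                 "family": family, "proto": 0, "flags": 0}
                for ip in self.ips
            ]
        return await self.default.resolve(host, port, family)

    async def close(self):
        await self.default.close()
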
@@ -157,7 +157,12 @@ async def validate_url(
         is_trusted: Boolean indicating if the hostname is in trusted_origins
         ip_addresses: List of IP addresses for the host; empty if the host is trusted
     """
-    parsed = parse_url(url)
+    # Canonicalize URL
+    url = url.strip("/ ").replace("\\", "/")
+    parsed = urlparse(url)
+    if not parsed.scheme:
+        url = f"http://{url}"
+        parsed = urlparse(url)
 
     # Check scheme
     if parsed.scheme not in ALLOWED_SCHEMES:
@@ -215,17 +220,6 @@ async def validate_url(
    )
 
 
-def parse_url(url: str) -> URL:
-    """Canonicalizes and parses a URL string."""
-    url = url.strip("/ ").replace("\\", "/")
-
-    # Ensure scheme is present for proper parsing
-    if not re.match(r"[a-z0-9+.\-]+://", url):
-        url = f"http://{url}"
-
-    return urlparse(url)
-
-
 def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
     """
     Pins a URL to a specific IP address to prevent DNS rebinding attacks.
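
A standalone sketch of the canonicalization that validate_url now inlines (equivalent to the removed parse_url helper; the function name is local to this example):

import re
from urllib.parse import urlparse

def canonicalize_example(url: str):
    # Trim stray slashes/spaces, normalize backslashes, and ensure a scheme so
    # urlparse() fills in netloc (e.g. "example.com\\path" -> "http://example.com/path").
    url = url.strip("/ ").replace("\\", "/")
    if not re.match(r"[a-z0-9+.\-]+://", url):
        url = f"http://{url}"
    return urlparse(url)

print(canonicalize_example("example.com\\path"))     # scheme='http', netloc='example.com', path='/path'
print(canonicalize_example("https://example.com/"))  # scheme='https', netloc='example.com'
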
@@ -467,7 +461,7 @@ class Requests:
         resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
         ssl_context = ssl.create_default_context()
         connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
-        session_kwargs: dict = {}
+        session_kwargs = {}
         if connector:
             session_kwargs["connector"] = connector
 
@@ -656,7 +656,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
     e2b_api_key: str = Field(default="", description="E2B API key")
     nvidia_api_key: str = Field(default="", description="Nvidia API key")
     mem0_api_key: str = Field(default="", description="Mem0 API key")
-    elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
 
     linear_client_id: str = Field(default="", description="Linear client ID")
     linear_client_secret: str = Field(default="", description="Linear client secret")
32 autogpt_platform/backend/backend/util/validation.py (new file)
@@ -0,0 +1,32 @@
+"""Validation utilities."""
+
+import re
+
+_UUID_V4_PATTERN = re.compile(
+    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
+    re.IGNORECASE,
+)
+
+
+def is_uuid_v4(text: str) -> bool:
+    """Check if text is a valid UUID v4.
+
+    Args:
+        text: String to validate
+
+    Returns:
+        True if the text is a valid UUID v4, False otherwise
+    """
+    return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
+
+
+def extract_uuids(text: str) -> list[str]:
+    """Extract all UUID v4 strings from text.
+
+    Args:
+        text: String to search for UUIDs
+
+    Returns:
+        List of unique UUIDs found (lowercase)
+    """
+    return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
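
A brief usage sketch for the new helpers, assuming the file above is importable as backend.util.validation (matching the other backend.util imports in this diff); the sample UUID is the user_id that appears in the snapshots below:

from backend.util.validation import extract_uuids, is_uuid_v4

print(is_uuid_v4("3e53486c-cf57-477e-ba2a-cb02dc828e1a"))  # True - version nibble 4, variant in [89ab]
print(is_uuid_v4("not-a-uuid"))                            # False

text = "ran graph-123 for user 3e53486c-cf57-477e-ba2a-cb02dc828e1a"
print(extract_uuids(text))  # ['3e53486c-cf57-477e-ba2a-cb02dc828e1a']
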
@@ -22,7 +22,6 @@ from backend.data.workspace import (
     soft_delete_workspace_file,
 )
 from backend.util.settings import Config
-from backend.util.virus_scanner import scan_content_safe
 from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
 
 logger = logging.getLogger(__name__)
@@ -188,9 +187,6 @@ class WorkspaceManager:
                 f"{Config().max_file_size_mb}MB limit"
             )
 
-        # Virus scan content before persisting (defense in depth)
-        await scan_content_safe(content, filename=filename)
-
         # Determine path with session scoping
         if path is None:
             path = f"/{filename}"
6796 autogpt_platform/backend/poetry.lock (generated)
File diff suppressed because it is too large
@@ -20,8 +20,7 @@ click = "^8.2.0"
 cryptography = "^45.0"
 discord-py = "^2.5.2"
 e2b-code-interpreter = "^1.5.2"
-elevenlabs = "^1.50.0"
-fastapi = "^0.128.0"
+fastapi = "^0.116.1"
 feedparser = "^6.0.11"
 flake8 = "^7.3.0"
 google-api-python-client = "^2.177.0"
@@ -35,7 +34,7 @@ jinja2 = "^3.1.6"
 jsonref = "^1.1.0"
 jsonschema = "^4.25.0"
 langfuse = "^3.11.0"
-launchdarkly-server-sdk = "^9.14.1"
+launchdarkly-server-sdk = "^9.12.0"
 mem0ai = "^0.1.115"
 moviepy = "^2.1.2"
 ollama = "^0.5.1"
@@ -52,8 +51,8 @@ prometheus-client = "^0.22.1"
 prometheus-fastapi-instrumentator = "^7.0.0"
 psutil = "^7.0.0"
 psycopg2-binary = "^2.9.10"
-pydantic = { extras = ["email"], version = "^2.12.5" }
-pydantic-settings = "^2.12.0"
+pydantic = { extras = ["email"], version = "^2.11.7" }
+pydantic-settings = "^2.10.1"
 pytest = "^8.4.1"
 pytest-asyncio = "^1.1.0"
 python-dotenv = "^1.1.1"
@@ -65,14 +64,13 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
 sqlalchemy = "^2.0.40"
 strenum = "^0.4.9"
 stripe = "^11.5.0"
-supabase = "2.27.2"
+supabase = "2.17.0"
 tenacity = "^9.1.2"
 todoist-api-python = "^2.1.7"
 tweepy = "^4.16.0"
-uvicorn = { extras = ["standard"], version = "^0.40.0" }
+uvicorn = { extras = ["standard"], version = "^0.35.0" }
 websockets = "^15.0"
 youtube-transcript-api = "^1.2.1"
-yt-dlp = "2025.12.08"
 zerobouncesdk = "^1.1.2"
 # NOTE: please insert new dependencies in their alphabetical location
 pytest-snapshot = "^0.9.0"
@@ -3,6 +3,7 @@
   "credentials_input_schema": {
     "properties": {},
     "required": [],
+    "title": "TestGraphCredentialsInputSchema",
     "type": "object"
   },
   "description": "A test graph",
@@ -1,14 +1,34 @@
 [
   {
-    "created_at": "2025-09-04T13:37:00",
+    "credentials_input_schema": {
+      "properties": {},
+      "required": [],
+      "title": "TestGraphCredentialsInputSchema",
+      "type": "object"
+    },
     "description": "A test graph",
     "forked_from_id": null,
     "forked_from_version": null,
+    "has_external_trigger": false,
+    "has_human_in_the_loop": false,
+    "has_sensitive_action": false,
     "id": "graph-123",
+    "input_schema": {
+      "properties": {},
+      "required": [],
+      "type": "object"
+    },
     "instructions": null,
     "is_active": true,
     "name": "Test Graph",
+    "output_schema": {
+      "properties": {},
+      "required": [],
+      "type": "object"
+    },
     "recommended_schedule_cron": null,
+    "sub_graphs": [],
+    "trigger_setup_info": null,
     "user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
     "version": 1
   }
@@ -111,7 +111,9 @@ class TestGenerateAgent:
         instructions = {"type": "instructions", "steps": ["Step 1"]}
         result = await core.generate_agent(instructions)
 
-        mock_external.assert_called_once_with(instructions, None, None, None)
+        # library_agents defaults to None
+        mock_external.assert_called_once_with(instructions, None)
+        # Result should have id, version, is_active added if not present
         assert result is not None
         assert result["name"] == "Test Agent"
         assert "id" in result
@@ -175,9 +177,8 @@ class TestGenerateAgentPatch:
         current_agent = {"nodes": [], "links": []}
         result = await core.generate_agent_patch("Add a node", current_agent)
 
-        mock_external.assert_called_once_with(
-            "Add a node", current_agent, None, None, None
-        )
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Add a node", current_agent, None)
         assert result == expected_result
 
     @pytest.mark.asyncio
Some files were not shown because too many files have changed in this diff.