Compare commits

..

1 Commit

Author SHA1 Message Date
0ubbe
5a3308074f Update frontend build based on commit 9538992eaf 2026-01-29 11:14:57 +00:00
211 changed files with 2459 additions and 16470 deletions

View File

@@ -29,7 +29,8 @@
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install"
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false

View File

@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
**Backend Entry Points:**
- `backend/backend/api/rest_api.py` - FastAPI application setup
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
### API Development
1. Update routes in `/backend/backend/api/features/`
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)

1
.gitignore vendored
View File

@@ -178,6 +178,5 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next

View File

@@ -16,6 +16,7 @@ See `docs/content/platform/getting-started.md` for setup instructions.
- Format Python code with `poetry run format`.
- Format frontend code using `pnpm format`.
## Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -32,17 +33,14 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components; use sub-components (local `/components` folder next to the parent component) when sensible
- Avoid large hooks; abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Do not type hook returns; let TypeScript infer as much as possible
- Never type with `any`; if no types are available, use `unknown`
## Testing
@@ -51,8 +49,22 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
Types: - feat - fix - refactor - ci - dx (developer experience)
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
Types:
- feat
- fix
- refactor
- ci
- dx (developer experience)
Scopes:
- platform
- platform/library
- platform/marketplace
- backend
- backend/executor
- frontend
- frontend/library
- frontend/marketplace
- blocks
## Pull requests

View File

@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -6,30 +6,152 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
## Component Documentation
## Essential Commands
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
### Backend Development
## Key Concepts
```bash
# Install dependencies
cd backend && poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend server
poetry run serve
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
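For illustration, a colocated snapshot test could look like the sketch below. It assumes a syrupy-style `snapshot` fixture (consistent with the `--snapshot-update` flag above); the endpoint and `test_client` fixture are hypothetical.
```python
# Hypothetical colocated test file, e.g. backend/server/routers/health_test.py.
# `snapshot` is assumed to come from a syrupy-style pytest plugin; `test_client`
# is a placeholder for whatever FastAPI test client fixture the project provides.
def test_health_endpoint_snapshot(snapshot, test_client):
    response = test_client.get("/health")

    assert response.status_code == 200
    # The JSON body is compared against the stored snapshot; run with
    # --snapshot-update to (re)generate it, then review the diff before committing.
    assert response.json() == snapshot
```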
### Frontend Development
```bash
# Install dependencies
cd frontend && pnpm i
# Generate API client from OpenAPI spec
pnpm generate:api
# Start development server
pnpm dev
# Run E2E tests
pnpm test
# Run Storybook for component development
pnpm storybook
# Build production
pnpm build
# Format and lint
pnpm format
# Type checking
pnpm types
```
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
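As a simplified illustration of the authentication pattern, routes resolve the current user through the `autogpt_libs.auth` dependency (the same import used in the chat routes later in this diff); the route below is hypothetical.
```python
# Minimal sketch of the JWT auth pattern; only autogpt_libs.auth.get_user_id
# is taken from this codebase, the route itself is made up for illustration.
from autogpt_libs import auth
from fastapi import APIRouter, Depends

router = APIRouter()


@router.get("/me/graphs")
async def list_my_graphs(user_id: str | None = Depends(auth.get_user_id)) -> dict:
    # user_id is resolved from the Supabase-issued JWT by the auth dependency.
    return {"user_id": user_id}
```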
### Frontend Architecture
- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development
### Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
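These models are accessed through the generated Prisma Python client. A minimal sketch follows; field names such as `userId` and `version` are assumptions based on the list above, not verified against `schema.prisma`.
```python
from prisma import Prisma


async def latest_graphs_for_user(user_id: str):
    """Fetch a user's most recent agent graph versions (illustrative only)."""
    db = Prisma()
    await db.connect()
    try:
        # Model accessors are the lower-cased model names, e.g. AgentGraph -> agentgraph.
        return await db.agentgraph.find_many(
            where={"userId": user_id},   # assumed field name
            order={"version": "desc"},   # assumed field name
            take=10,
        )
    finally:
        await db.disconnect()
```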
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
#### Docker Environment Loading Order
@@ -45,12 +167,83 @@ AutoGPT Platform is a monorepo containing:
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Common Development Tasks
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when creating many new blocks, analyze each block's interface and consider whether they would work well together in a graph-based editor or would struggle to connect productively (e.g., do the inputs and outputs tie together well?).
If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.
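A minimal sketch of what such a block could look like, assuming the `Block`/`BlockSchema`/`SchemaField` interfaces live under `backend.data` as implied by the quick steps above (exact import paths, constructor arguments, and any `ProviderBuilder` wiring may differ; see the Block SDK Guide):
```python
import uuid

from backend.data.block import Block, BlockOutput, BlockSchema  # assumed import path
from backend.data.model import SchemaField  # assumed import path


class WordCountBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to count words in")

    class Output(BlockSchema):
        word_count: int = SchemaField(description="Number of whitespace-separated words")

    def __init__(self):
        super().__init__(
            # In a real block, generate this once with uuid.uuid4() and paste the literal.
            id=str(uuid.uuid4()),
            description="Counts the words in a piece of text",
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Outputs are yielded as (output_name, value) pairs.
        yield "word_count", len(input_data.text.split())
```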
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
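For example, a new route file under `/backend/backend/server/routers/` might look roughly like this; the `widgets` resource and file name are hypothetical:
```python
# Hypothetical file: /backend/backend/server/routers/widgets.py
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(prefix="/widgets", tags=["widgets"])


class WidgetResponse(BaseModel):
    id: str
    name: str


@router.get("/{widget_id}", response_model=WidgetResponse)
async def get_widget(widget_id: str) -> WidgetResponse:
    # Real handlers would load from the database and enforce user ID checks.
    return WidgetResponse(id=widget_id, name="example widget")
```
A colocated `widgets_test.py` would then exercise the route, e.g. using the auth fixtures described in TESTING.md.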
### Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
- Add `usePageName.ts` hook for logic
- Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components; use sub-components (local `/components` folder next to the parent component) when sensible
- Avoid large hooks; abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
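The behaviour described above can be pictured with a sketch like the following; this is an illustrative Starlette middleware, not the actual implementation in `/backend/backend/server/middleware/security.py`:
```python
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

# Illustrative allow list; the real CACHEABLE_PATHS lives in the security middleware.
CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health")


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if not any(request.url.path.startswith(prefix) for prefix in CACHEABLE_PATHS):
            # Default-deny: anything not explicitly allow-listed must not be cached.
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```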
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Create the PR aginst the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
- Use conventional commit messages (see below)/
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
- Run the github pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests

View File

@@ -152,7 +152,6 @@ REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
ELEVENLABS_API_KEY=
# Data & Search Services
E2B_API_KEY=

View File

@@ -19,6 +19,3 @@ load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
migrations/*/rollback*.sql
# Workspace files
workspaces/

View File

@@ -1,170 +0,0 @@
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when creating many new blocks, analyze each block's interface and consider whether they would work well together in a graph-based editor or would struggle to connect productively (e.g., do the inputs and outputs tie together well?).
If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -62,12 +62,10 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
&& rm -rf /var/lib/apt/lists/*
# Copy only necessary files from builder

View File

@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
#### Using Global Auth Fixtures
Two global auth fixtures are provided by `backend/api/conftest.py`:
Two global auth fixtures are provided by `backend/server/conftest.py`:
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
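A hedged sketch of how one of these fixtures might be used in a colocated test; only the fixture names come from `backend/server/conftest.py`, while the endpoint and `test_client` fixture are hypothetical:
```python
# Hypothetical route test relying on the global mock_jwt_user fixture.
def test_authenticated_request_uses_mocked_user(mock_jwt_user, test_client):
    # With mock_jwt_user active, JWT validation resolves to "test-user-id",
    # so protected endpoints can be called without issuing real tokens.
    response = test_client.get("/api/some-protected-endpoint")

    assert response.status_code == 200
```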

View File

@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
)
# Taken from backend/api/features/store/db.py
# Taken from backend/server/v2/store/db.py
def sanitize_query(query: str | None) -> str | None:
if query is None:
return query

View File

@@ -1,368 +0,0 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -1,344 +0,0 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.6", description="Default model to use"
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
class StreamFinish(StreamBaseResponse):

View File

@@ -1,23 +1,19 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -188,14 +166,13 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
SessionDetailResponse: Details for the requested session, or None if not found.
"""
session = await get_chat_session(session_id, user_id)
@@ -203,28 +180,11 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
id=session.session_id,
@@ -232,7 +192,6 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -252,112 +211,49 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
subscriber_queue = None
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass # Client disconnected - background task continues
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -470,251 +366,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live (TTL) settings for chat streams, which determine
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========

File diff suppressed because it is too large

View File

@@ -1,704 +0,0 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
@dataclass
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
All delivery is via Redis Streams - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - it uses the Redis Stream for replay and blocking XREAD for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
# Non-blocking read: if no messages exist after last_message_id the replay step
# should return immediately (block=0 would mean "block indefinitely")
messages = await redis.xread({stream_key: last_message_id}, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for task {task_id}"
)
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for task {task_id}, "
"queue completely blocked"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
_listener_tasks.pop(queue_id, None)
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> bool:
"""Mark a task as completed and publish finish event.
This is idempotent - calling multiple times with the same task_id is safe.
Uses atomic compare-and-swap via Lua script to prevent race conditions.
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
Returns:
True if task was newly marked completed, False if already completed/failed
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
# Atomic compare-and-swap: only update if status is "running"
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish finish event for task {task_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
return True
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if not task_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
# Note: Redis client uses decode_responses=True, so keys/values are strings
return ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
async def get_task_with_expiry_info(
task_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
Returns (task, error_code) where error_code is:
- None if task found
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
- "TASK_NOT_FOUND" if neither exists
Args:
task_id: Task ID to look up
Returns:
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Check if stream still has data (metadata expired but stream hasn't)
stream_len = await redis.xlen(stream_key)
if stream_len > 0:
return None, "TASK_EXPIRED"
return None, "TASK_NOT_FOUND"
# Note: Redis client uses decode_responses=True, so keys/values are strings
return (
ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
),
None,
)
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
redis = await get_redis_async()
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
if task_user_id and user_id != task_user_id:
continue
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
if cursor == 0:
break
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
# Map response types to their corresponding classes
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
ResponseType.ERROR.value: StreamError,
ResponseType.USAGE.value: StreamUsage,
ResponseType.HEARTBEAT.value: StreamHeartbeat,
}
chunk_type = chunk_data.get("type")
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
if chunk_class is None:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
try:
return chunk_class(**chunk_data)
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the task state is in Redis.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to track
"""
_local_tasks[task_id] = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
Cancels the XREAD-based listener task associated with this subscriber queue
to prevent resource leaks.
Args:
task_id: Task ID
subscriber_queue: The subscriber's queue used to look up the listener task
"""
queue_id = id(subscriber_queue)
listener_entry = _listener_tasks.pop(queue_id, None)
if listener_entry is None:
logger.debug(
f"No listener task found for task {task_id} queue {queue_id} "
"(may have already completed)"
)
return
stored_task_id, listener_task = listener_entry
if stored_task_id != task_id:
logger.warning(
f"Task ID mismatch in unsubscribe: expected {task_id}, "
f"found {stored_task_id}"
)
if listener_task.done():
logger.debug(f"Listener task for task {task_id} already completed")
return
# Cancel the listener task
listener_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(listener_task, timeout=5.0)
except asyncio.CancelledError:
# Expected - the task was successfully cancelled
pass
except asyncio.TimeoutError:
logger.warning(
f"Timeout waiting for listener task cancellation for task {task_id}"
)
except Exception as e:
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
logger.debug(f"Successfully unsubscribed from task {task_id}")

View File

@@ -1,79 +0,0 @@
# CoPilot Tools - Future Ideas
## Multimodal Image Support for CoPilot
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
**Backend Solution:**
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
```python
# Before sending to LLM, scan for workspace image references
# and inject them as image content parts
# Example message transformation:
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
# TO: {"role": "assistant", "content": [
# {"type": "text", "text": "Generated image: workspace://abc123"},
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# ]}
```
**Where to implement:**
- In the chat stream handler before calling the LLM
- Or in a message preprocessing step
- Need to fetch image from workspace, convert to base64, add as image content
**Considerations:**
- Only do this for image MIME types (image/png, image/jpeg, etc.)
- May want a size limit (don't pass 10MB images)
- Track which images were "shown" to the AI for frontend indicator
- Cost implications - vision API calls are more expensive
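A rough backend sketch of that preprocessing step is below. It is only illustrative: `workspace_manager.read_file()` and the 5MB cap are assumed helpers/limits, not existing APIs, and the message shape mirrors the transformation shown above.
```python
import base64
import re

WORKSPACE_REF = re.compile(r"workspace://[A-Za-z0-9_-]+")

async def inject_workspace_images(message: dict, workspace_manager) -> dict:
    """Expand workspace:// image references into image_url content parts (sketch)."""
    text = message.get("content")
    if not isinstance(text, str):
        return message  # content is already structured; nothing to do
    parts: list[dict] = [{"type": "text", "text": text}]
    for ref in WORKSPACE_REF.findall(text):
        file = await workspace_manager.read_file(ref)  # assumed helper
        if not file.mime_type.startswith("image/"):
            continue  # only inject actual images
        if len(file.data) > 5_000_000:
            continue  # assumed size cap to avoid huge vision payloads
        encoded = base64.b64encode(file.data).decode()
        parts.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:{file.mime_type};base64,{encoded}"},
            }
        )
    return {**message, "content": parts}
```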
**Frontend Solution:**
Show visual indicator on workspace files in chat:
- If AI saw the image: normal display
- If AI didn't see it: overlay icon saying "AI can't see this image"
Requires response metadata indicating which `workspace://` refs were passed to the model.
---
## Output Post-Processing Layer for run_block
**Problem:** Many blocks produce large outputs that:
- Consume massive context (a 100KB image becomes ~133KB of base64 text, i.e. tens of thousands of tokens)
- Can't fit in the conversation window
- Break downstream processing and cause high LLM costs
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
**Benefits:**
1. **Centralized** - one place to handle all output processing
2. **Future-proof** - new blocks automatically get output processing
3. **Keeps blocks pure** - they don't need to know about context constraints
4. **Handles all large outputs** - not just images
**Processing Rules:**
- Detect base64 data URIs → save to workspace, return `workspace://` reference
- Truncate very long strings (>N chars) with truncation note
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
- Handle nested large outputs in dicts recursively
- Cap total output size
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
**Example:**
```python
def _process_outputs_for_context(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
max_string_length: int = 10000,
max_array_preview: int = 5,
) -> dict[str, list[Any]]:
"""Process block outputs to prevent context bloat."""
processed = {}
for name, values in outputs.items():
processed[name] = [_process_value(v, workspace_manager) for v in values]
return processed
```
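A possible shape for the `_process_value` helper referenced above (purely a sketch: `workspace_manager.save_bytes()` and the thresholds are assumptions, not existing APIs):
```python
import base64
import re
from typing import Any

DATA_URI = re.compile(r"^data:(?P<mime>[\w/+.-]+);base64,(?P<payload>.+)$", re.DOTALL)

def _process_value(
    value: Any,
    workspace_manager,
    max_string_length: int = 10000,
    max_array_preview: int = 5,
) -> Any:
    """Apply the processing rules above to a single output value (sketch)."""
    if isinstance(value, str):
        match = DATA_URI.match(value)
        if match:
            # Base64 data URI -> persist to workspace and return a reference
            data = base64.b64decode(match.group("payload"))
            ref = workspace_manager.save_bytes(data, mime_type=match.group("mime"))  # assumed helper
            return f"workspace://{ref}"
        if len(value) > max_string_length:
            omitted = len(value) - max_string_length
            return value[:max_string_length] + f"... [truncated {omitted} chars]"
        return value
    if isinstance(value, list) and len(value) > max_array_preview:
        preview = [_process_value(v, workspace_manager) for v in value[:max_array_preview]]
        return {"note": f"Array with {len(value)} items, first {max_array_preview} shown", "preview": preview}
    if isinstance(value, dict):
        return {k: _process_value(v, workspace_manager) for k, v in value.items()}
    return value
```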

View File

@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
@@ -19,12 +18,6 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WriteWorkspaceFileTool,
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -35,7 +28,6 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
@@ -45,11 +37,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),
"delete_workspace_file": DeleteWorkspaceFileTool(),
}
# Export individual tool instances for backwards compatibility

View File

@@ -2,58 +2,27 @@
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
save_agent_to_library,
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
# Core functions
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",
"get_agent_as_json",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",
# Service
"is_external_service_configured",
"check_external_service_health",
]

View File

@@ -1,17 +1,13 @@
"""Core agent generation functions."""
import logging
import re
import uuid
from typing import Any, NotRequired, TypedDict
from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
from backend.util.exceptions import DatabaseError, NotFoundError
from backend.data.graph import Graph, Link, Node, create_graph
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
@@ -21,72 +17,6 @@ from .service import (
logger = logging.getLogger(__name__)
class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment."""
status: str
correctness_score: NotRequired[float]
activity_summary: NotRequired[str]
class LibraryAgentSummary(TypedDict):
"""Summary of a library agent for sub-agent composition.
Includes recent executions to help the LLM decide whether to use this agent.
Each execution shows status, correctness_score (0-1), and activity_summary.
"""
graph_id: str
graph_version: int
name: str
description: str
input_schema: dict[str, Any]
output_schema: dict[str, Any]
recent_executions: NotRequired[list[ExecutionSummary]]
class MarketplaceAgentSummary(TypedDict):
"""Summary of a marketplace agent for sub-agent composition."""
name: str
description: str
sub_heading: str
creator: str
is_marketplace_agent: bool
class DecompositionStep(TypedDict, total=False):
"""A single step in decomposed instructions."""
description: str
action: str
block_name: str
tool: str
name: str
class DecompositionResult(TypedDict, total=False):
"""Result from decompose_goal - can be instructions, questions, or error."""
type: str
steps: list[DecompositionStep]
questions: list[dict[str, Any]]
error: str
error_type: str
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
@@ -106,422 +36,15 @@ def _check_service_configured() -> None:
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def extract_uuids_from_text(text: str) -> list[str]:
"""Extract all UUID v4 strings from text.
Args:
text: Text that may contain UUIDs (e.g., user's goal description)
Returns:
List of unique UUIDs found in the text (lowercase)
"""
matches = _UUID_PATTERN.findall(text)
return list({m.lower() for m in matches})
async def get_library_agent_by_id(
user_id: str, agent_id: str
) -> LibraryAgentSummary | None:
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
This function tries multiple lookup strategies:
1. First tries to find by graph_id (AgentGraph primary key)
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
This handles both cases:
- User provides graph_id (e.g., from AgentExecutorBlock)
- User provides library agent ID (e.g., from library URL)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
LibraryAgentSummary if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except DatabaseError:
raise
except Exception as e:
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
max_results: int = 15,
) -> list[LibraryAgentSummary]:
"""Fetch user's library agents formatted for Agent Generator.
Uses search-based fetching to return relevant agents instead of all agents.
This is more scalable for users with large libraries.
Includes recent_executions list to help the LLM assess agent quality:
- Each execution has status, correctness_score (0-1), and activity_summary
- This gives the LLM concrete examples of recent performance
Args:
user_id: The user ID
search_query: Optional search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
max_results: Maximum number of agents to return (default 15)
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
)
results: list[LibraryAgentSummary] = []
for agent in response.agents:
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
continue
summary = LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
if agent.recent_executions:
exec_summaries: list[ExecutionSummary] = []
for ex in agent.recent_executions:
exec_sum = ExecutionSummary(status=ex.status)
if ex.correctness_score is not None:
exec_sum["correctness_score"] = ex.correctness_score
if ex.activity_summary:
exec_sum["activity_summary"] = ex.activity_summary
exec_summaries.append(exec_sum)
summary["recent_executions"] = exec_summaries
results.append(summary)
return results
except DatabaseError:
raise
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
return []
async def search_marketplace_agents_for_generation(
search_query: str,
max_results: int = 10,
) -> list[LibraryAgentSummary]:
"""Search marketplace agents formatted for Agent Generator.
Fetches marketplace agents and their full schemas so they can be used
as sub-agents in generated workflows.
Args:
search_query: Search term to find relevant public agents
max_results: Maximum number of agents to return (default 10)
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
try:
response = await store_db.get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
)
agents_with_graphs = [
agent for agent in response.agents if agent.agent_graph_id
]
if not agents_with_graphs:
return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
graph_id = agent.agent_graph_id
if graph_id and graph_id in graphs:
graph = graphs[graph_id]
results.append(
LibraryAgentSummary(
graph_id=graph.id,
graph_version=graph.version,
name=agent.agent_name,
description=agent.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)
return results
except Exception as e:
logger.warning(f"Failed to search marketplace agents: {e}")
return []
async def get_all_relevant_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
include_library: bool = True,
include_marketplace: bool = True,
max_library_results: int = 15,
max_marketplace_results: int = 10,
) -> list[AgentSummary]:
"""Fetch relevant agents from library and/or marketplace.
Searches both user's library and marketplace by default.
Explicitly mentioned UUIDs in the search query are always looked up.
Args:
user_id: The user ID
search_query: Search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
include_library: Whether to search user's library (default True)
include_marketplace: Whether to also search marketplace (default True)
max_library_results: Max library agents to return (default 15)
max_marketplace_results: Max marketplace agents to return (default 10)
Returns:
List of AgentSummary with full schemas (both library and marketplace agents)
"""
agents: list[AgentSummary] = []
seen_graph_ids: set[str] = set()
if search_query:
mentioned_uuids = extract_uuids_from_text(search_query)
for graph_id in mentioned_uuids:
if graph_id == exclude_graph_id:
continue
agent = await get_library_agent_by_graph_id(user_id, graph_id)
agent_graph_id = agent.get("graph_id") if agent else None
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(agent_graph_id)
logger.debug(
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
)
if include_library:
library_agents = await get_library_agents_for_generation(
user_id=user_id,
search_query=search_query,
exclude_graph_id=exclude_graph_id,
max_results=max_library_results,
)
for agent in library_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)
if include_marketplace and search_query:
marketplace_agents = await search_marketplace_agents_for_generation(
search_query=search_query,
max_results=max_marketplace_results,
)
for agent in marketplace_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)
return agents
def extract_search_terms_from_steps(
decomposition_result: DecompositionResult | dict[str, Any],
) -> list[str]:
"""Extract search terms from decomposed instruction steps.
Analyzes the decomposition result to extract relevant keywords
for additional library agent searches.
Args:
decomposition_result: Result from decompose_goal containing steps
Returns:
List of unique search terms extracted from steps
"""
search_terms: list[str] = []
if decomposition_result.get("type") != "instructions":
return search_terms
steps = decomposition_result.get("steps", [])
if not steps:
return search_terms
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
for step in steps:
for key in step_keys:
value = step.get(key) # type: ignore[union-attr]
if isinstance(value, str) and len(value) > 3:
search_terms.append(value)
seen: set[str] = set()
unique_terms: list[str] = []
for term in search_terms:
term_lower = term.lower()
if term_lower not in seen:
seen.add(term_lower)
unique_terms.append(term)
return unique_terms
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
) -> list[AgentSummary] | list[dict[str, Any]]:
"""Enrich library agents list with additional searches based on decomposed steps.
This implements two-phase search: after decomposition, we search for additional
relevant agents based on the specific steps identified.
Args:
user_id: The user ID
decomposition_result: Result from decompose_goal containing steps
existing_agents: Already fetched library agents from initial search
exclude_graph_id: Optional graph ID to exclude
include_marketplace: Whether to also search marketplace
max_additional_results: Max additional agents per search term (default 10)
Returns:
Combined list of library agents (existing + newly discovered)
"""
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return existing_agents
existing_ids: set[str] = set()
existing_names: set[str] = set()
for agent in existing_agents:
agent_name = agent.get("name")
if agent_name and isinstance(agent_name, str):
existing_names.add(agent_name.lower())
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
for term in search_terms[:3]:
try:
additional_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=term,
exclude_graph_id=exclude_graph_id,
include_marketplace=include_marketplace,
max_library_results=max_additional_results,
max_marketplace_results=5,
)
for agent in additional_agents:
agent_name = agent.get("name")
if not agent_name or not isinstance(agent_name, str):
continue
agent_name_lower = agent_name.lower()
if agent_name_lower in existing_names:
continue
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and graph_id in existing_ids:
continue
all_agents.append(agent)
existing_names.add(agent_name_lower)
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
except DatabaseError:
logger.error(f"Database error searching for agents with term '{term}'")
raise
except Exception as e:
logger.warning(
f"Failed to search for additional agents with term '{term}': {e}"
)
logger.debug(
f"Enriched library agents: {len(existing_agents)} initial + "
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
)
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
@@ -531,47 +54,26 @@ async def decompose_goal(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
return await decompose_goal_external(description, context)
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Agent JSON dict or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
result = await generate_agent_external(instructions)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
# Ensure required fields
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
@@ -581,12 +83,6 @@ async def generate_agent(
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
pass
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.
@@ -595,55 +91,25 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
Returns:
Graph ready for saving
Raises:
AgentJsonValidationError: If required fields are missing from nodes or links
"""
nodes = []
for idx, n in enumerate(agent_json.get("nodes", [])):
block_id = n.get("block_id")
if not block_id:
node_id = n.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Node '{node_id}' is missing required field 'block_id'"
)
for n in agent_json.get("nodes", []):
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=block_id,
block_id=n["block_id"],
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)
links = []
for idx, link_data in enumerate(agent_json.get("links", [])):
source_id = link_data.get("source_id")
sink_id = link_data.get("sink_id")
source_name = link_data.get("source_name")
sink_name = link_data.get("sink_name")
missing_fields = []
if not source_id:
missing_fields.append("source_id")
if not sink_id:
missing_fields.append("sink_id")
if not source_name:
missing_fields.append("source_name")
if not sink_name:
missing_fields.append("sink_name")
if missing_fields:
link_id = link_data.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
)
for link_data in agent_json.get("links", []):
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=source_id,
sink_id=sink_id,
source_name=source_name,
sink_name=sink_name,
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
is_static=link_data.get("is_static", False),
)
links.append(link)
@@ -659,6 +125,27 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
)
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]
# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4()) # Also give links new IDs
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
@@ -672,21 +159,63 @@ async def save_agent_to_library(
Returns:
Tuple of (created Graph, LibraryAgent)
"""
from backend.data.graph import get_graph_all_versions
graph = json_to_graph(agent_json)
if is_update:
return await library_db.update_graph_in_library(graph, user_id)
return await library_db.create_graph_in_library(graph, user_id)
# For updates, keep the same graph ID but increment version
# and reassign node/link IDs to avoid conflicts
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
# Reassign node IDs (but keep graph ID the same)
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
# For new agents, always generate a fresh UUID to avoid collisions
graph.id = str(uuid.uuid4())
graph.version = 1
# Reassign all node IDs as well
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
# Save to database
created_graph = await create_graph(graph, user_id)
# Add to user's library (or update existing library agent)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
return created_graph, library_agents[0]
def graph_to_json(graph: Graph) -> dict[str, Any]:
"""Convert a Graph object to JSON format for the agent generator.
async def get_agent_as_json(
graph_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
graph: Graph object to convert
graph_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict
Agent as JSON dict or None if not found
"""
from backend.data.graph import get_graph
# Try to get the graph (version=None gets the active version)
graph = await get_graph(graph_id, version=None, user_id=user_id)
if not graph:
return None
# Convert to JSON format
nodes = []
for node in graph.nodes:
nodes.append(
@@ -723,41 +252,8 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
}
async def get_agent_as_json(
agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
agent_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict or None if not found
"""
graph = await get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:
pass
if not graph:
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -769,57 +265,13 @@ async def generate_agent_patch(
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Updated agent JSON, clarifying questions dict, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)
return await generate_agent_patch_external(update_request, current_agent)

View File

@@ -1,95 +0,0 @@
"""Error handling utilities for agent generator."""
import re
def _sanitize_error_details(details: str) -> str:
"""Sanitize error details to remove sensitive information.
Strips common patterns that could expose internal system info:
- File paths (Unix and Windows)
- Database connection strings
- URLs with credentials
- Stack trace internals
Args:
details: Raw error details string
Returns:
Sanitized error details safe for user display
"""
sanitized = re.sub(
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
)
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
sanitized = re.sub(
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
)
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
sanitized = re.sub(r", line \d+", "", sanitized)
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
return sanitized.strip()
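# Illustrative behaviour (based on the patterns above):
#   _sanitize_error_details("Error in /app/backend/blocks/email.py while sending")
#       -> "Error in [path] while sending"
#   _sanitize_error_details("postgres://user:pass@db:5432/app connection refused")
#       -> "[database_url] connection refused"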
def get_user_message_for_error(
error_type: str,
operation: str = "process the request",
llm_parse_message: str | None = None,
validation_message: str | None = None,
error_details: str | None = None,
) -> str:
"""Get a user-friendly error message based on error type.
This function maps internal error types to user-friendly messages,
providing a consistent experience across different agent operations.
Args:
error_type: The error type from the external service
(e.g., "llm_parse_error", "timeout", "rate_limit")
operation: Description of what operation failed, used in the default
message (e.g., "analyze the goal", "generate the agent")
llm_parse_message: Custom message for llm_parse_error type
validation_message: Custom message for validation_error type
error_details: Optional additional details about the error
Returns:
User-friendly error message suitable for display to the user
"""
base_message = ""
if error_type == "llm_parse_error":
base_message = (
llm_parse_message
or "The AI had trouble processing this request. Please try again."
)
elif error_type == "validation_error":
base_message = (
validation_message
or "The generated agent failed validation. "
"This usually happens when the agent structure doesn't match "
"what the platform expects. Please try simplifying your goal "
"or breaking it into smaller parts."
)
elif error_type == "patch_error":
base_message = (
"Failed to apply the changes. The modification couldn't be "
"validated. Please try a different approach or simplify the change."
)
elif error_type in ("timeout", "llm_timeout"):
base_message = (
"The request took too long to process. This can happen with "
"complex agents. Please try again or simplify your goal."
)
elif error_type in ("rate_limit", "llm_rate_limit"):
base_message = "The service is currently busy. Please try again in a moment."
else:
base_message = f"Failed to {operation}. Please try again."
if error_details:
details = _sanitize_error_details(error_details)
if len(details) > 200:
details = details[:200] + "..."
base_message += f"\n\nTechnical details: {details}"
return base_message

View File

@@ -14,70 +14,6 @@ from backend.util.settings import Settings
logger = logging.getLogger(__name__)
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
@@ -117,16 +53,13 @@ def _get_client() -> httpx.AsyncClient:
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
description: str, context: str = ""
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
@@ -134,17 +67,15 @@ async def decompose_goal_external(
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
Or None on error
"""
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
# Build the request payload
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
try:
response = await client.post("/api/decompose-description", json=payload)
@@ -152,13 +83,8 @@ async def decompose_goal_external(
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
logger.error(f"External service returned error: {data.get('error')}")
return None
# Map the response to the expected format
response_type = data.get("type")
@@ -180,162 +106,88 @@ async def decompose_goal_external(
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
return None
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
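A minimal, illustrative sketch of how a caller might branch on the result shapes documented above (handler bodies elided):
# Illustrative only - the branches mirror the shapes listed in the docstring.
async def handle_decomposition(description: str) -> None:
    result = await decompose_goal_external(description)
    if result is None:
        return  # unexpected failure: ask the user to retry
    kind = result.get("type")
    if kind == "error":
        ...  # surface result["error"] / result["error_type"] to the user
    elif kind in ("unachievable_goal", "vague_goal"):
        ...  # explain the problem, optionally offer result.get("suggested_goal")
    else:  # "instructions"
        ...  # hand the steps to generate_agent_external()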
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
instructions: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
Agent JSON dict or None on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
logger.error(f"External service returned error: {data.get('error')}")
return None
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
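A short sketch (illustrative; the operation/task IDs are placeholders) of the three outcomes generate_agent_external can produce when operation_id/task_id are passed, mirroring the handling in create_agent_tool.py further down:
# Illustrative only - not part of this diff.
async def handle_generation(instructions: dict) -> None:
    result = await generate_agent_external(
        instructions, operation_id="op-123", task_id="task-456"
    )
    if result is None:
        ...  # service unavailable or unexpected failure
    elif result.get("status") == "accepted":
        ...  # 202 hand-off: the Redis Streams completion consumer finishes the job
    elif result.get("type") == "error":
        ...  # map result["error_type"] to a message via get_user_message_for_error()
    else:
        ...  # result is the generated agent JSON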
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
Updated agent JSON, clarifying questions dict, or None on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
logger.error(f"External service returned error: {data.get('error')}")
return None
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
@@ -344,99 +196,18 @@ async def generate_agent_patch_external(
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
logger.error(f"HTTP error calling external agent generator: {e}")
return None
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
logger.error(f"Request error calling external agent generator: {e}")
return None
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
async def get_blocks_external() -> list[dict[str, Any]] | None:

View File

@@ -1,7 +1,6 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
import re
from typing import Literal
from backend.api.features.library import db as library_db
@@ -20,85 +19,6 @@ logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
_UUID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
re.IGNORECASE,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
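A condensed, illustrative sketch of the lookup order these helpers give search_agents(): direct ID lookup for UUID-shaped queries, fuzzy library search as the fallback:
# Illustrative only - condensed from the library branch of search_agents().
async def find_in_library(user_id: str, query: str):
    if _is_uuid(query):
        agent = await _get_library_agent_by_id(user_id, query)
        if agent:
            return [agent]
    results = await library_db.list_library_agents(
        user_id=user_id, search_term=query, page_size=10
    )
    return results.agents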
async def search_agents(
query: str,
@@ -149,37 +69,29 @@ async def search_agents(
is_featured=False,
)
)
else:
if _is_uuid(query):
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
if agent:
agents.append(agent)
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
@@ -206,9 +118,9 @@ async def search_agents(
]
)
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
else f"No agents matching '{query}' found in your library."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
@@ -224,10 +136,10 @@ async def search_agents(
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
"Please ask the user if they would like to use any of these agents."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
)
return AgentsFoundResponse(

View File

@@ -1,251 +0,0 @@
"""
Detect and save embedded binary data in block outputs.
Scans stdout_logs and other string outputs for embedded base64 patterns,
saves detected binary content to workspace, and replaces the base64 with
workspace:// references. This reduces LLM output token usage by ~97% for
file generation tasks.
Primary use case: ExecuteCodeBlock prints base64 to stdout, which appears
in stdout_logs. Without this processor, the LLM would re-type the entire
base64 string when saving files.
"""
import base64
import binascii
import hashlib
import logging
import re
import uuid
from typing import Any, Optional
from backend.util.file import sanitize_filename
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
logger = logging.getLogger(__name__)
# Minimum decoded size to process (filters out small base64 strings)
MIN_DECODED_SIZE = 1024 # 1KB
# Pattern to find base64 chunks in text (at least 100 chars to be worth checking)
# Matches continuous base64 characters (with optional whitespace for line wrapping),
# optionally ending with = padding
EMBEDDED_BASE64_PATTERN = re.compile(r"[A-Za-z0-9+/\s]{100,}={0,2}")
# Magic numbers for binary file detection
MAGIC_SIGNATURES = [
(b"\x89PNG\r\n\x1a\n", "png"),
(b"\xff\xd8\xff", "jpg"),
(b"%PDF-", "pdf"),
(b"GIF87a", "gif"),
(b"GIF89a", "gif"),
(b"RIFF", "webp"), # Also check content[8:12] == b'WEBP'
]
async def process_binary_outputs(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
block_name: str,
) -> dict[str, list[Any]]:
"""
Scan all string values in outputs for embedded base64 binary content.
Save detected binaries to workspace and replace with references.
Args:
outputs: Block execution outputs (dict of output_name -> list of values)
workspace_manager: WorkspaceManager instance with session scoping
block_name: Name of the block (used in generated filenames)
Returns:
Processed outputs with embedded base64 replaced by workspace references
"""
cache: dict[str, str] = {} # content_hash -> workspace_ref
processed: dict[str, list[Any]] = {}
for name, items in outputs.items():
processed_items = []
for item in items:
processed_items.append(
await _process_value(item, workspace_manager, block_name, cache)
)
processed[name] = processed_items
return processed
async def _process_value(
value: Any,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> Any:
"""Recursively process a value, detecting embedded base64 in strings."""
if isinstance(value, dict):
result = {}
for k, v in value.items():
result[k] = await _process_value(v, wm, block, cache)
return result
if isinstance(value, list):
return [await _process_value(v, wm, block, cache) for v in value]
if isinstance(value, str) and len(value) > MIN_DECODED_SIZE:
return await _extract_and_replace_base64(value, wm, block, cache)
return value
async def _extract_and_replace_base64(
text: str,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> str:
"""
Find embedded base64 in text, save binaries, replace with references.
Scans for base64 patterns, validates each as binary via magic numbers,
saves valid binaries to workspace, and replaces the base64 portion
(plus any surrounding markers) with the workspace reference.
"""
result = text
offset = 0
for match in EMBEDDED_BASE64_PATTERN.finditer(text):
b64_str = match.group(0)
# Try to decode and validate
detection = _decode_and_validate(b64_str)
if detection is None:
continue
content, ext = detection
# Save to workspace
ref = await _save_binary(content, ext, wm, block, cache)
if ref is None:
continue
# Calculate replacement bounds (include surrounding markers if present)
start, end = match.start(), match.end()
start, end = _expand_to_markers(text, start, end)
# Apply replacement with offset adjustment
adj_start = start + offset
adj_end = end + offset
result = result[:adj_start] + ref + result[adj_end:]
offset += len(ref) - (end - start)
return result
def _decode_and_validate(b64_str: str) -> Optional[tuple[bytes, str]]:
"""
Decode base64 and validate it's a known binary format.
Tries multiple 4-byte aligned offsets to handle cases where marker text
(e.g., "START" from "PDF_BASE64_START") bleeds into the regex match.
Base64 works in 4-char chunks, so we only check aligned offsets.
Returns (content, extension) if valid binary, None otherwise.
"""
# Strip whitespace for RFC 2045 line-wrapped base64
normalized = re.sub(r"\s+", "", b64_str)
# Try offsets 0, 4, 8, ... up to 32 chars (handles markers up to ~24 chars)
# This handles cases like "STARTJVBERi0..." where "START" bleeds into match
for char_offset in range(0, min(33, len(normalized)), 4):
candidate = normalized[char_offset:]
try:
content = base64.b64decode(candidate, validate=True)
except (ValueError, binascii.Error):
continue
# Must meet minimum size
if len(content) < MIN_DECODED_SIZE:
continue
# Check magic numbers
for magic, ext in MAGIC_SIGNATURES:
if content.startswith(magic):
# Special case for WebP: RIFF container, verify "WEBP" at offset 8
if magic == b"RIFF":
if len(content) < 12 or content[8:12] != b"WEBP":
continue
return content, ext
return None
def _expand_to_markers(text: str, start: int, end: int) -> tuple[int, int]:
"""
Expand replacement bounds to include surrounding markers if present.
Handles patterns like:
- ---BASE64_START---\\n{base64}\\n---BASE64_END---
- [BASE64]{base64}[/BASE64]
- Or just the raw base64
"""
# Common marker patterns to strip (order matters - check longer patterns first)
start_markers = [
"PDF_BASE64_START",
"---BASE64_START---\n",
"---BASE64_START---",
"[BASE64]\n",
"[BASE64]",
]
end_markers = [
"PDF_BASE64_END",
"\n---BASE64_END---",
"---BASE64_END---",
"\n[/BASE64]",
"[/BASE64]",
]
# Check for start markers
for marker in start_markers:
marker_start = start - len(marker)
if marker_start >= 0 and text[marker_start:start] == marker:
start = marker_start
break
# Check for end markers
for marker in end_markers:
marker_end = end + len(marker)
if marker_end <= len(text) and text[end:marker_end] == marker:
end = marker_end
break
return start, end
async def _save_binary(
content: bytes,
ext: str,
wm: WorkspaceManager,
block: str,
cache: dict[str, str],
) -> Optional[str]:
"""
Save binary content to workspace with deduplication.
Returns workspace://file-id reference, or None on failure.
"""
content_hash = hashlib.sha256(content).hexdigest()
if content_hash in cache:
return cache[content_hash]
try:
safe_block = sanitize_filename(block)[:20].lower()
filename = f"{safe_block}_{uuid.uuid4().hex[:12]}.{ext}"
# Scan for viruses before saving
await scan_content_safe(content, filename=filename)
file = await wm.write_file(content, filename)
ref = f"workspace://{file.id}"
cache[content_hash] = ref
return ref
except Exception as e:
logger.warning("Failed to save binary output: %s", e)
return None

View File

@@ -8,17 +8,13 @@ from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -99,10 +95,6 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -110,24 +102,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
library_agents = None
if user_id:
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
decomposition_result = await decompose_goal(description, context)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -140,31 +117,15 @@ class CreateAgentTool(BaseTool):
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
session_id=session_id,
)
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
"description": description[:100]
}, # Include context for debugging
session_id=session_id,
)
# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
@@ -183,6 +144,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
@@ -209,27 +171,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
agent_json = await generate_agent(decomposition_result)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -242,47 +186,11 @@ class CreateAgentTool(BaseTool):
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
"description": description[:100]
}, # Include context for debugging
session_id=session_id,
)
@@ -291,6 +199,7 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -305,6 +214,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -322,7 +232,7 @@ class CreateAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -1,337 +0,0 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@property
def name(self) -> str:
return "customize_agent"
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
),
},
"modifications": {
"type": "string",
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["agent_id", "modifications"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
error="invalid_agent_id_format",
session_id=session_id,
)
creator_username, agent_slug = parts
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -9,15 +9,12 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -105,10 +102,6 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -123,6 +116,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
@@ -132,34 +126,14 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
library_agents = None
if user_id:
try:
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
# Build the update request with context
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
result = await generate_agent_patch(update_request, current_agent)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -178,42 +152,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if LLM returned clarifying questions
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
@@ -232,6 +171,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Result is the updated agent JSON
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
@@ -239,6 +179,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -254,6 +195,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -271,7 +213,7 @@ class EditAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -28,18 +28,10 @@ class ResponseType(str, Enum):
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Base response model
@@ -70,10 +62,6 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)
class AgentsFoundResponse(ToolResponseBase):
@@ -200,20 +188,6 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
unrecognized_fields: list[str] = Field(
description="List of input field names that were not recognized"
)
inputs: dict[str, Any] = Field(
description="The agent's valid input schema for reference"
)
graph_id: str | None = None
graph_version: int | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
@@ -372,15 +346,11 @@ class OperationStartedResponse(ToolResponseBase):
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
@@ -404,20 +374,3 @@ class OperationInProgressResponse(ToolResponseBase):
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
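For context, a hypothetical sketch of the detection the docstring above describes; the handler shown here is not part of this changeset, only the status check is taken from the model:
# Hypothetical handler logic - illustrative only.
async def run_long_running_tool(tool, **tool_args):
    result = await tool.execute(**tool_args)
    if getattr(result, "status", None) == "accepted":
        # Async hand-off: skip LLM continuation; the Redis Streams completion
        # consumer will deliver the final result for this operation/task.
        return result
    ...  # otherwise continue the normal tool-result -> LLM flow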

View File

@@ -30,7 +30,6 @@ from .models import (
ErrorResponse,
ExecutionOptions,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide

View File

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

View File

@@ -1,23 +1,17 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .binary_output_processor import process_binary_outputs
from .models import (
BlockOutputResponse,
ErrorResponse,
@@ -79,22 +73,15 @@ class RunBlockTool(BaseTool):
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -107,33 +94,14 @@ class RunBlockTool(BaseTool):
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
@@ -147,8 +115,8 @@ class RunBlockTool(BaseTool):
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
@@ -216,9 +184,10 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
user_id, block
)
if missing_credentials:
@@ -254,48 +223,11 @@ class RunBlockTool(BaseTool):
)
try:
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
"execution_context": ExecutionContext(),
}
for field_name, cred_meta in matched_credentials.items():
@@ -323,16 +255,6 @@ class RunBlockTool(BaseTool):
):
outputs[output_name].append(output_data)
# Post-process outputs to save binary content to workspace
workspace_manager = WorkspaceManager(
user_id=user_id,
workspace_id=workspace.id,
session_id=session.session_id,
)
outputs = await process_binary_outputs(
dict(outputs), workspace_manager, block.name
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,

View File

@@ -1,518 +0,0 @@
"""Tests for embedded binary detection in block outputs."""
import base64
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .binary_output_processor import (
_decode_and_validate,
_expand_to_markers,
process_binary_outputs,
)
@pytest.fixture
def mock_workspace_manager():
"""Create a mock workspace manager that returns predictable file IDs."""
wm = MagicMock()
async def mock_write_file(content, filename):
file = MagicMock()
file.id = f"file-{filename[:10]}"
return file
wm.write_file = AsyncMock(side_effect=mock_write_file)
return wm
def _make_pdf_base64(size: int = 2000) -> str:
"""Create a valid PDF base64 string of specified size."""
pdf_content = b"%PDF-1.4 " + b"x" * size
return base64.b64encode(pdf_content).decode()
def _make_png_base64(size: int = 2000) -> str:
"""Create a valid PNG base64 string of specified size."""
png_content = b"\x89PNG\r\n\x1a\n" + b"\x00" * size
return base64.b64encode(png_content).decode()
# =============================================================================
# Decode and Validate Tests
# =============================================================================
class TestDecodeAndValidate:
"""Tests for _decode_and_validate function."""
def test_detects_pdf_magic_number(self):
"""Should detect valid PDF by magic number."""
pdf_b64 = _make_pdf_base64()
result = _decode_and_validate(pdf_b64)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content.startswith(b"%PDF-")
def test_detects_png_magic_number(self):
"""Should detect valid PNG by magic number."""
png_b64 = _make_png_base64()
result = _decode_and_validate(png_b64)
assert result is not None
content, ext = result
assert ext == "png"
def test_detects_jpeg_magic_number(self):
"""Should detect valid JPEG by magic number."""
jpeg_content = b"\xff\xd8\xff\xe0" + b"\x00" * 2000
jpeg_b64 = base64.b64encode(jpeg_content).decode()
result = _decode_and_validate(jpeg_b64)
assert result is not None
_, ext = result
assert ext == "jpg"
def test_detects_gif_magic_number(self):
"""Should detect valid GIF by magic number."""
gif_content = b"GIF89a" + b"\x00" * 2000
gif_b64 = base64.b64encode(gif_content).decode()
result = _decode_and_validate(gif_b64)
assert result is not None
_, ext = result
assert ext == "gif"
def test_detects_webp_magic_number(self):
"""Should detect valid WebP by magic number."""
webp_content = b"RIFF\x00\x00\x00\x00WEBP" + b"\x00" * 2000
webp_b64 = base64.b64encode(webp_content).decode()
result = _decode_and_validate(webp_b64)
assert result is not None
_, ext = result
assert ext == "webp"
def test_rejects_small_content(self):
"""Should reject content smaller than threshold."""
small_pdf = b"%PDF-1.4 small"
small_b64 = base64.b64encode(small_pdf).decode()
result = _decode_and_validate(small_b64)
assert result is None
def test_rejects_no_magic_number(self):
"""Should reject content without recognized magic number."""
random_content = b"This is just random text" * 100
random_b64 = base64.b64encode(random_content).decode()
result = _decode_and_validate(random_b64)
assert result is None
def test_rejects_invalid_base64(self):
"""Should reject invalid base64."""
result = _decode_and_validate("not-valid-base64!!!")
assert result is None
def test_rejects_riff_without_webp(self):
"""Should reject RIFF files that aren't WebP (e.g., WAV)."""
wav_content = b"RIFF\x00\x00\x00\x00WAVE" + b"\x00" * 2000
wav_b64 = base64.b64encode(wav_content).decode()
result = _decode_and_validate(wav_b64)
assert result is None
def test_handles_line_wrapped_base64(self):
"""Should handle RFC 2045 line-wrapped base64."""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# Simulate line wrapping at 76 chars
wrapped = "\n".join(pdf_b64[i : i + 76] for i in range(0, len(pdf_b64), 76))
result = _decode_and_validate(wrapped)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content == pdf_content
# =============================================================================
# Marker Expansion Tests
# =============================================================================
class TestExpandToMarkers:
"""Tests for _expand_to_markers function."""
def test_expands_base64_start_end_markers(self):
"""Should expand to include ---BASE64_START--- and ---BASE64_END---."""
text = "prefix\n---BASE64_START---\nABCDEF\n---BASE64_END---\nsuffix"
# Base64 "ABCDEF" is at position 26-32
start, end = _expand_to_markers(text, 26, 32)
assert text[start:end] == "---BASE64_START---\nABCDEF\n---BASE64_END---"
def test_expands_bracket_markers(self):
"""Should expand to include [BASE64] and [/BASE64] markers."""
text = "prefix[BASE64]ABCDEF[/BASE64]suffix"
# Base64 is at position 14-20
start, end = _expand_to_markers(text, 14, 20)
assert text[start:end] == "[BASE64]ABCDEF[/BASE64]"
def test_no_expansion_without_markers(self):
"""Should not expand if no markers present."""
text = "prefix ABCDEF suffix"
start, end = _expand_to_markers(text, 7, 13)
assert start == 7
assert end == 13
# =============================================================================
# Process Binary Outputs Tests
# =============================================================================
@pytest.fixture
def mock_scan():
"""Patch virus scanner for tests."""
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
) as mock:
yield mock
class TestProcessBinaryOutputs:
"""Tests for process_binary_outputs function."""
@pytest.mark.asyncio
async def test_detects_embedded_pdf_in_stdout_logs(
self, mock_workspace_manager, mock_scan
):
"""Should detect and replace embedded PDF in stdout_logs."""
pdf_b64 = _make_pdf_base64()
stdout = f"PDF generated!\n---BASE64_START---\n{pdf_b64}\n---BASE64_END---\n"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "ExecuteCodeBlock"
)
# Should contain workspace reference, not base64
assert "workspace://" in result["stdout_logs"][0]
assert pdf_b64 not in result["stdout_logs"][0]
assert "PDF generated!" in result["stdout_logs"][0]
mock_workspace_manager.write_file.assert_called_once()
@pytest.mark.asyncio
async def test_detects_embedded_png_without_markers(
self, mock_workspace_manager, mock_scan
):
"""Should detect embedded PNG even without markers."""
png_b64 = _make_png_base64()
stdout = f"Image created: {png_b64} done"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "ExecuteCodeBlock"
)
assert "workspace://" in result["stdout_logs"][0]
assert "Image created:" in result["stdout_logs"][0]
assert "done" in result["stdout_logs"][0]
@pytest.mark.asyncio
async def test_preserves_small_strings(self, mock_workspace_manager, mock_scan):
"""Should not process small strings."""
outputs = {"stdout_logs": ["small output"]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
assert result["stdout_logs"][0] == "small output"
mock_workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_preserves_non_binary_large_strings(
self, mock_workspace_manager, mock_scan
):
"""Should preserve large strings that don't contain valid binary."""
large_text = "A" * 5000 # Large string - decodes to nulls, no magic number
outputs = {"stdout_logs": [large_text]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
assert result["stdout_logs"][0] == large_text
mock_workspace_manager.write_file.assert_not_called()
@pytest.mark.asyncio
async def test_deduplicates_identical_content(
self, mock_workspace_manager, mock_scan
):
"""Should save identical content only once."""
pdf_b64 = _make_pdf_base64()
stdout1 = f"First: {pdf_b64}"
stdout2 = f"Second: {pdf_b64}"
outputs = {"stdout_logs": [stdout1, stdout2]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Both should have references
assert "workspace://" in result["stdout_logs"][0]
assert "workspace://" in result["stdout_logs"][1]
# But only one write
assert mock_workspace_manager.write_file.call_count == 1
@pytest.mark.asyncio
async def test_handles_multiple_binaries_in_one_string(
self, mock_workspace_manager, mock_scan
):
"""Should handle multiple embedded binaries in a single string."""
pdf_b64 = _make_pdf_base64()
png_b64 = _make_png_base64()
stdout = f"PDF: {pdf_b64}\nPNG: {png_b64}"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should have two workspace references
assert result["stdout_logs"][0].count("workspace://") == 2
assert mock_workspace_manager.write_file.call_count == 2
@pytest.mark.asyncio
async def test_processes_nested_structures(self, mock_workspace_manager, mock_scan):
"""Should recursively process nested dicts and lists."""
pdf_b64 = _make_pdf_base64()
outputs = {"result": [{"nested": {"deep": f"data: {pdf_b64}"}}]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
assert "workspace://" in result["result"][0]["nested"]["deep"]
@pytest.mark.asyncio
async def test_graceful_degradation_on_save_failure(
self, mock_workspace_manager, mock_scan
):
"""Should preserve original on save failure."""
mock_workspace_manager.write_file = AsyncMock(
side_effect=Exception("Storage error")
)
pdf_b64 = _make_pdf_base64()
stdout = f"PDF: {pdf_b64}"
outputs = {"stdout_logs": [stdout]}
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should keep original since save failed
assert pdf_b64 in result["stdout_logs"][0]
# =============================================================================
# Offset Loop Tests (handling marker bleed-in)
# =============================================================================
class TestOffsetLoopHandling:
"""Tests for the offset-aligned decoding that handles marker bleed-in."""
def test_handles_4char_aligned_prefix(self):
"""Should detect base64 when a 4-char aligned prefix bleeds into match.
When 'TEST' (4 chars, aligned) bleeds in, offset 4 finds valid base64.
"""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# 4-char prefix (aligned)
with_prefix = f"TEST{pdf_b64}"
result = _decode_and_validate(with_prefix)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content == pdf_content
def test_handles_8char_aligned_prefix(self):
"""Should detect base64 when an 8-char prefix bleeds into match."""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# 8-char prefix (aligned)
with_prefix = f"TESTTEST{pdf_b64}"
result = _decode_and_validate(with_prefix)
assert result is not None
content, ext = result
assert ext == "pdf"
def test_handles_misaligned_prefix(self):
"""Should handle misaligned prefix by finding a valid aligned offset.
'START' is 5 chars (misaligned). The loop tries offsets 0, 4, 8...
Since characters 0-4 include 'START' which is invalid base64 on its own,
we need the full PDF base64 to eventually decode correctly at some offset.
"""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# 5-char prefix - misaligned, but offset 4 should start mid-'START'
# and offset 8 will be past the prefix
with_prefix = f"START{pdf_b64}"
result = _decode_and_validate(with_prefix)
# Should find valid PDF at some offset (8 in this case)
assert result is not None
_, ext = result
assert ext == "pdf"
def test_handles_pdf_base64_start_marker_bleed(self):
"""Should handle PDF_BASE64_START marker bleeding into regex match.
This is the real-world case: regex matches 'STARTJVBERi0...' because
'START' chars are in the base64 alphabet. Offset loop skips past it.
PDF_BASE64_START is 16 chars (4-aligned), so offset 16 finds valid base64.
"""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
# Simulate regex capturing 'PDF_BASE64_START' + base64 together
# This happens when there's no delimiter between marker and content
with_full_marker = f"PDF_BASE64_START{pdf_b64}"
result = _decode_and_validate(with_full_marker)
assert result is not None
_, ext = result
assert ext == "pdf"
def test_clean_base64_works_at_offset_zero(self):
"""Should detect clean base64 at offset 0 without issues."""
pdf_content = b"%PDF-1.4 " + b"x" * 2000
pdf_b64 = base64.b64encode(pdf_content).decode()
result = _decode_and_validate(pdf_b64)
assert result is not None
content, ext = result
assert ext == "pdf"
assert content == pdf_content
# =============================================================================
# PDF Marker Tests
# =============================================================================
class TestPdfMarkerExpansion:
"""Tests for PDF_BASE64_START/END marker handling."""
def test_expands_pdf_base64_start_marker(self):
"""Should expand to include PDF_BASE64_START marker."""
text = "prefixPDF_BASE64_STARTABCDEF"
# Base64 'ABCDEF' is at position 22-28
start, end = _expand_to_markers(text, 22, 28)
assert text[start:end] == "PDF_BASE64_STARTABCDEF"
def test_expands_pdf_base64_end_marker(self):
"""Should expand to include PDF_BASE64_END marker."""
text = "ABCDEFPDF_BASE64_ENDsuffix"
# Base64 'ABCDEF' is at position 0-6
start, end = _expand_to_markers(text, 0, 6)
assert text[start:end] == "ABCDEFPDF_BASE64_END"
def test_expands_both_pdf_markers(self):
"""Should expand to include both PDF_BASE64_START and END."""
text = "xPDF_BASE64_STARTABCDEFPDF_BASE64_ENDy"
# Base64 'ABCDEF' is at position 17-23
start, end = _expand_to_markers(text, 17, 23)
assert text[start:end] == "PDF_BASE64_STARTABCDEFPDF_BASE64_END"
def test_partial_marker_not_expanded(self):
"""Should not expand if only partial marker present."""
text = "BASE64_STARTABCDEF" # Missing 'PDF_' prefix
start, end = _expand_to_markers(text, 12, 18)
# Should not expand since it's not the full marker
assert start == 12
assert end == 18
@pytest.mark.asyncio
async def test_full_pipeline_with_pdf_markers(self, mock_workspace_manager):
"""Test full pipeline with PDF_BASE64_START/END markers."""
pdf_b64 = _make_pdf_base64()
stdout = f"Output: PDF_BASE64_START{pdf_b64}PDF_BASE64_END done"
outputs = {"stdout_logs": [stdout]}
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
):
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should have workspace reference
assert "workspace://" in result["stdout_logs"][0]
# Markers should be consumed along with base64
assert "PDF_BASE64_START" not in result["stdout_logs"][0]
assert "PDF_BASE64_END" not in result["stdout_logs"][0]
# Surrounding text preserved
assert "Output:" in result["stdout_logs"][0]
assert "done" in result["stdout_logs"][0]
# =============================================================================
# Virus Scanning Tests
# =============================================================================
class TestVirusScanning:
"""Tests for virus scanning integration."""
@pytest.mark.asyncio
async def test_calls_virus_scanner_before_save(self, mock_workspace_manager):
"""Should call scan_content_safe before writing file."""
pdf_b64 = _make_pdf_base64()
stdout = f"PDF: {pdf_b64}"
outputs = {"stdout_logs": [stdout]}
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
) as mock_scan:
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Verify scanner was called
mock_scan.assert_called_once()
# Verify file was written after scan
mock_workspace_manager.write_file.assert_called_once()
# Verify result has workspace reference
assert "workspace://" in result["stdout_logs"][0]
@pytest.mark.asyncio
async def test_virus_scan_failure_preserves_original(self, mock_workspace_manager):
"""Should preserve original if virus scan fails."""
pdf_b64 = _make_pdf_base64()
stdout = f"PDF: {pdf_b64}"
outputs = {"stdout_logs": [stdout]}
with patch(
"backend.api.features.chat.tools.binary_output_processor.scan_content_safe",
new_callable=AsyncMock,
side_effect=Exception("Virus detected"),
):
result = await process_binary_outputs(
outputs, mock_workspace_manager, "TestBlock"
)
# Should keep original since scan failed
assert pdf_b64 in result["stdout_logs"][0]
# File should not have been written
mock_workspace_manager.write_file.assert_not_called()
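The tests above pin down two techniques: identifying decoded payloads by magic number, and retrying the decode at 4-character-aligned offsets so that marker text bleeding into the regex match can be skipped. Below is a rough standalone sketch of that approach; it is purely illustrative and not the repository's implementation (MAGIC_NUMBERS, MIN_SIZE, and decode_and_validate are made-up names, and it does not replicate every edge case these tests cover).

import base64

# Hypothetical magic-number table; the real module may use a different set.
MAGIC_NUMBERS = {
    b"%PDF-": "pdf",
    b"\x89PNG\r\n\x1a\n": "png",
    b"\xff\xd8\xff": "jpg",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
}
MIN_SIZE = 1024  # ignore payloads too small to be worth extracting


def decode_and_validate(candidate: str) -> tuple[bytes, str] | None:
    """Decode a base64 candidate and identify it by magic number.

    Offsets 0, 4, 8, ... are tried so that marker text that bled into the
    regex match (e.g. 'PDF_BASE64_START', 16 chars, 4-aligned) can be
    stepped over without corrupting the rest of the decode.
    """
    cleaned = "".join(candidate.split())  # tolerate RFC 2045 line wrapping
    for offset in range(0, min(len(cleaned), 64) + 1, 4):
        chunk = cleaned[offset:]
        chunk = chunk[: len(chunk) - len(chunk) % 4]  # keep length a multiple of 4
        if not chunk:
            break
        try:
            content = base64.b64decode(chunk, validate=True)
        except Exception:
            continue
        if len(content) < MIN_SIZE:
            continue
        for magic, ext in MAGIC_NUMBERS.items():
            if content.startswith(magic):
                return content, ext
        if content[:4] == b"RIFF" and content[8:12] == b"WEBP":  # WebP, not WAV/AVI
            return content, "webp"
    return None

Because base64 encodes 3 bytes per 4 characters, only prefixes whose length is a multiple of 4 can be skipped cleanly, which is why the tests distinguish aligned prefixes ('TEST', 'PDF_BASE64_START') from misaligned ones ('START').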

View File

@@ -8,12 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import (
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError
@@ -271,21 +266,13 @@ async def match_user_credentials_to_graph(
credential_requirements,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes
# Find first matching credential by provider and type
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and (
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
),
None,
)
@@ -309,17 +296,10 @@ async def match_user_credentials_to_graph(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} (requires {', '.join(error_parts)})"
f"{credential_field_name} "
f"(requires provider in {list(credential_requirements.provider)}, "
f"type in {list(credential_requirements.supported_types)})"
)
logger.info(
@@ -329,35 +309,6 @@ async def match_user_credentials_to_graph(
return graph_credentials_inputs, missing_creds
def _credential_has_required_scopes(
credential: OAuth2Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an OAuth2 credential has all the scopes required by the input."""
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
# Check that credential host matches required host.
# Host-scoped credential inputs are grouped by host, so any item from the set works.
return credential.matches_url(list(requirements.discriminator_values)[0])
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],
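The hunk above drops the scope- and host-aware filtering from credential matching. For reference, the removed checks amounted to a scope-superset test for OAuth2 credentials and a host match for host-scoped ones; here is a simplified, self-contained sketch (the types and names are stand-ins, not the real CredentialsFieldInfo API).

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Requirements:
    """Simplified stand-in for the credential field requirements."""
    required_scopes: set[str] = field(default_factory=set)
    discriminator_values: set[str] = field(default_factory=set)


def has_required_scopes(credential_scopes: list[str], req: Requirements) -> bool:
    # No required scopes -> any credential of the right provider/type matches.
    if not req.required_scopes:
        return True
    return set(credential_scopes).issuperset(req.required_scopes)


def is_for_host(matches_url: Callable[[str], bool], req: Requirements) -> bool:
    # Host-scoped inputs are grouped by host, so any discriminator value works.
    if not req.discriminator_values:
        return True
    return matches_url(next(iter(req.discriminator_values)))


# Example: a credential with repo + read:org scopes satisfies a field that
# only requires the 'repo' scope.
assert has_required_scopes(["repo", "read:org"], Requirements(required_scopes={"repo"}))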

View File

@@ -1,620 +0,0 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
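ReadWorkspaceFileTool above returns inline content only for small text files and otherwise hands back metadata plus a workspace:// reference, because base64 inflates payloads by roughly a third. A minimal sketch of that size decision follows, using illustrative names rather than the tool's actual helpers.

import base64

MAX_INLINE_SIZE_BYTES = 32 * 1024  # mirrors the tool's threshold
TEXT_PREFIXES = ("text/", "application/json", "application/xml")


def should_inline(size_bytes: int, mime_type: str, force_download_url: bool = False) -> bool:
    """Return True when it is safe to embed the file content in the tool response."""
    if force_download_url:
        return False
    is_small = size_bytes <= MAX_INLINE_SIZE_BYTES
    is_text = mime_type.startswith(TEXT_PREFIXES)
    return is_small and is_text


def base64_size(size_bytes: int) -> int:
    """Size of a payload once base64-encoded (roughly 4/3 of the raw size)."""
    return len(base64.b64encode(b"\x00" * size_bytes))


# A 100 KB binary becomes ~133 KB of base64, which is why large files are
# returned as a workspace:// reference instead of inline content.
print(base64_size(100 * 1024))  # 136536 bytes
print(should_inline(10 * 1024, "text/plain"))  # True
print(should_inline(5 * 1024 * 1024, "application/pdf"))  # False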

View File

@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
on_graph_deactivate,
)
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson
@@ -42,7 +39,6 @@ async def list_library_agents(
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
page: int = 1,
page_size: int = 50,
include_executions: bool = False,
) -> library_model.LibraryAgentResponse:
"""
Retrieves a paginated list of LibraryAgent records for a given user.
@@ -53,9 +49,6 @@ async def list_library_agents(
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
page: Current page (1-indexed).
page_size: Number of items per page.
include_executions: Whether to include execution data for status calculation.
Defaults to False for performance (UI fetches status separately).
Set to True when accurate status/metrics are needed (e.g., agent generator).
Returns:
A LibraryAgentResponse containing the list of agents and pagination details.
@@ -83,6 +76,7 @@ async def list_library_agents(
"isArchived": False,
}
# Build search filter if applicable
if search_term:
where_clause["OR"] = [
{
@@ -99,6 +93,7 @@ async def list_library_agents(
},
]
# Determine sorting
order_by: prisma.types.LibraryAgentOrderByInput | None = None
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
@@ -110,7 +105,7 @@ async def list_library_agents(
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include=library_agent_include(
user_id, include_nodes=False, include_executions=include_executions
user_id, include_nodes=False, include_executions=False
),
order=order_by,
skip=(page - 1) * page_size,
@@ -540,92 +535,6 @@ async def update_agent_version_in_library(
return library_model.LibraryAgent.from_db(lib)
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agents = await create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
return created_graph, library_agents[0]
async def update_graph_in_library(
graph: graph_db.Graph,
user_id: str,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new version of an existing graph and update the library entry."""
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
current_active_version = (
next((v for v in existing_versions if v.is_active), None)
if existing_versions
else None
)
graph.version = (
max(v.version for v in existing_versions) + 1 if existing_versions else 1
)
graph_model = graph_db.make_graph_model(graph, user_id)
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
created_graph = await graph_db.create_graph(graph_model, user_id)
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
if not library_agent:
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
library_agent = await update_library_agent_version_and_settings(
user_id, created_graph
)
if created_graph.is_active:
created_graph = await on_graph_activate(created_graph, user_id=user_id)
await graph_db.set_graph_active_version(
graph_id=created_graph.id,
version=created_graph.version,
user_id=user_id,
)
if current_active_version:
await on_graph_deactivate(current_active_version, user_id=user_id)
return created_graph, library_agent
async def update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
"""Update library agent to point to new graph version and sync settings."""
library = await update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
async def update_library_agent(
library_agent_id: str,
user_id: str,

View File

@@ -9,7 +9,6 @@ import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.json import loads as json_loads
from backend.util.models import Pagination
if TYPE_CHECKING:
@@ -17,10 +16,10 @@ if TYPE_CHECKING:
class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED"
HEALTHY = "HEALTHY"
WAITING = "WAITING"
ERROR = "ERROR"
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state
class MarketplaceListingCreator(pydantic.BaseModel):
@@ -40,30 +39,6 @@ class MarketplaceListing(pydantic.BaseModel):
creator: MarketplaceListingCreator
class RecentExecution(pydantic.BaseModel):
"""Summary of a recent execution for quality assessment.
Used by the LLM to understand the agent's recent performance with specific examples
rather than just aggregate statistics.
"""
status: str
correctness_score: float | None = None
activity_summary: str | None = None
def _parse_settings(settings: dict | str | None) -> GraphSettings:
"""Parse settings from database, handling both dict and string formats."""
if settings is None:
return GraphSettings()
try:
if isinstance(settings, str):
settings = json_loads(settings)
return GraphSettings.model_validate(settings)
except Exception:
return GraphSettings()
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -73,7 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -89,7 +64,7 @@ class LibraryAgent(pydantic.BaseModel):
description: str
instructions: str | None = None
input_schema: dict[str, Any]
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any]
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
@@ -106,19 +81,25 @@ class LibraryAgent(pydantic.BaseModel):
)
trigger_setup_info: Optional[GraphTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
execution_count: int = 0
success_rate: float | None = None
avg_correctness_score: float | None = None
recent_executions: list[RecentExecution] = pydantic.Field(
default_factory=list,
description="List of recent executions with status, score, and summary",
)
# Whether the user can access the underlying graph
can_access_graph: bool
# Indicates if this agent is the latest version
is_latest_version: bool
# Whether the agent is marked as favorite by the user
is_favorite: bool
# Recommended schedule cron (from marketplace agents)
recommended_schedule_cron: str | None = None
# User-specific settings for this library agent
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
# Marketplace listing information if the agent has been published
marketplace_listing: Optional["MarketplaceListing"] = None
@staticmethod
@@ -142,6 +123,7 @@ class LibraryAgent(pydantic.BaseModel):
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
@@ -154,6 +136,7 @@ class LibraryAgent(pydantic.BaseModel):
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""
# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
@@ -162,55 +145,13 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
success_count = sum(
1
for e in executions
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
)
success_rate = (success_count / execution_count) * 100
correctness_scores = []
for e in executions:
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
correctness_scores.append(float(score))
if correctness_scores:
avg_correctness_score = sum(correctness_scores) / len(
correctness_scores
)
recent_executions: list[RecentExecution] = []
for e in executions:
exec_score: float | None = None
exec_summary: str | None = None
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
exec_score = float(score)
summary = e.stats.get("activity_status")
if summary is not None and isinstance(summary, str):
exec_summary = summary
exec_status = (
e.executionStatus.value
if hasattr(e.executionStatus, "value")
else str(e.executionStatus)
)
recent_executions.append(
RecentExecution(
status=exec_status,
correctness_score=exec_score,
activity_summary=exec_summary,
)
)
# Check if user can access the graph
can_access_graph = agent.AgentGraph.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
# Build marketplace_listing if available
marketplace_listing_data = None
if store_listing and store_listing.ActiveVersion and profile:
creator_data = MarketplaceListingCreator(
@@ -249,15 +190,11 @@ class LibraryAgent(pydantic.BaseModel):
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
execution_count=execution_count,
success_rate=success_rate,
avg_correctness_score=avg_correctness_score,
recent_executions=recent_executions,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
settings=GraphSettings.model_validate(agent.settings),
marketplace_listing=marketplace_listing_data,
)
@@ -283,15 +220,18 @@ def _calculate_agent_status(
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False
for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1
# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
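_calculate_agent_status collapses the per-execution status counts into a single LibraryAgentStatus. The hunk is trimmed after the QUEUED branch, so the remaining priority order in the sketch below is an assumption based on the enum comments above (a failure wins, then queued work, then still-running work, otherwise everything completed); it is not the route's exact code.

from enum import Enum


class Status(str, Enum):
    COMPLETED = "COMPLETED"
    HEALTHY = "HEALTHY"
    WAITING = "WAITING"
    ERROR = "ERROR"


def calculate_status(status_counts: dict[str, int]) -> Status:
    """Collapse per-execution statuses into one agent-level status."""
    if status_counts.get("FAILED", 0) > 0:
        return Status.ERROR
    if status_counts.get("QUEUED", 0) > 0:
        return Status.WAITING
    if status_counts.get("RUNNING", 0) > 0:
        return Status.HEALTHY
    return Status.COMPLETED


print(calculate_status({"COMPLETED": 3, "QUEUED": 1}))  # Status.WAITING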

View File

@@ -112,7 +112,6 @@ async def get_store_agents(
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
agent_graph_id=agent.get("agentGraphId", ""),
)
store_agents.append(store_agent)
except Exception as e:
@@ -171,7 +170,6 @@ async def get_store_agents(
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.agentGraphId,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)

View File

@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Use a unique search term to avoid matching other test data
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
# Create multiple items
content_ids = []
for i in range(5):
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"{unique_term} item number {i}",
searchable_text=f"pagination test item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
# Get second page
page2_results, total2 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=2,
page_size=2,

View File

@@ -600,7 +600,6 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -660,7 +659,6 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,

View File

@@ -38,7 +38,6 @@ class StoreAgent(pydantic.BaseModel):
description: str
runs: int
rating: float
agent_graph_id: str
class StoreAgentsResponse(pydantic.BaseModel):

View File

@@ -26,13 +26,11 @@ def test_store_agent():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
@@ -48,7 +46,6 @@ def test_store_agents_response():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(

View File

@@ -82,7 +82,6 @@ def test_get_agents_featured(
description="Featured agent description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-1",
)
],
pagination=store_model.Pagination(
@@ -128,7 +127,6 @@ def test_get_agents_by_creator(
description="Creator agent description",
runs=50,
rating=4.0,
agent_graph_id="test-graph-2",
)
],
pagination=store_model.Pagination(
@@ -174,7 +172,6 @@ def test_get_agents_sorted(
description="Top agent description",
runs=1000,
rating=5.0,
agent_graph_id="test-graph-3",
)
],
pagination=store_model.Pagination(
@@ -220,7 +217,6 @@ def test_get_agents_search(
description="Specific search term description",
runs=75,
rating=4.2,
agent_graph_id="test-graph-search",
)
],
pagination=store_model.Pagination(
@@ -266,7 +262,6 @@ def test_get_agents_category(
description="Category agent description",
runs=60,
rating=4.1,
agent_graph_id="test-graph-category",
)
],
pagination=store_model.Pagination(
@@ -311,7 +306,6 @@ def test_get_agents_pagination(
description=f"Agent {i} description",
runs=i * 10,
rating=4.0,
agent_graph_id="test-graph-2",
)
for i in range(5)
],

View File

@@ -33,7 +33,6 @@ class TestCacheDeletion:
description="Test description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=Pagination(

View File

@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
from backend.util.virus_scanner import scan_content_safe
from .library import db as library_db
from .library import model as library_model
from .store.model import StoreAgentDetails
@@ -822,16 +823,18 @@ async def update_graph(
graph: graph_db.Graph,
user_id: Annotated[str, Security(get_user_id)],
) -> graph_db.GraphModel:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
graph.version = max(g.version for g in existing_versions) + 1
current_active_version = next((v for v in existing_versions if v.is_active), None)
graph = graph_db.make_graph_model(graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
graph.validate_graph(for_run=False)
@@ -839,23 +842,27 @@ async def update_graph(
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
await library_db.update_library_agent_version_and_settings(
user_id, new_graph_version
)
# Keep the library agent up to date with the new active version
await _update_library_agent_version_and_settings(user_id, new_graph_version)
# Handle activation of the new graph first to ensure continuity
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
if current_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_version, user_id=user_id)
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
new_graph_version_with_subgraphs = await graph_db.get_graph(
graph_id,
new_graph_version.version,
user_id=user_id,
include_subgraphs=True,
)
assert new_graph_version_with_subgraphs
assert new_graph_version_with_subgraphs # make type checker happy
return new_graph_version_with_subgraphs
@@ -893,15 +900,33 @@ async def set_graph_active_version(
)
# Keep the library agent up to date with the new active version
await library_db.update_library_agent_version_and_settings(
user_id, new_active_graph
)
await _update_library_agent_version_and_settings(user_id, new_active_graph)
if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(current_active_graph, user_id=user_id)
async def _update_library_agent_version_and_settings(
user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
library = await library_db.update_agent_version_in_library(
user_id, agent_graph.id, agent_graph.version
)
updated_settings = GraphSettings.from_graph(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
)
if updated_settings != library.settings:
library = await library_db.update_library_agent(
library_agent_id=library.id,
user_id=user_id,
settings=updated_settings,
)
return library
@v1_router.patch(
path="/graphs/{graph_id}/settings",
summary="Update graph settings",

View File

@@ -1 +0,0 @@
# Workspace API feature module

View File

@@ -1,122 +0,0 @@
"""
Workspace API routes for managing user file storage.
"""
import logging
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
Removes/replaces characters that could break the header or inject new headers.
Uses RFC5987 encoding for non-ASCII characters.
"""
# Remove CR, LF, and null bytes (header injection prevention)
sanitized = re.sub(r"[\r\n\x00]", "", filename)
# Escape quotes
sanitized = sanitized.replace('"', '\\"')
# For non-ASCII, use RFC5987 filename* parameter
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
dependencies=[fastapi.Security(requires_user)],
)
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
Handles both local storage (direct streaming) and GCS (signed URL redirect
with fallback to streaming).
"""
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> Response:
"""
Download a file by its ID.
Returns the file content directly or redirects to a signed URL for GCS.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
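The deleted route module's _sanitize_filename_for_header produces one of two header shapes depending on whether the filename is pure ASCII. A standalone illustration of the same idea, reimplemented here so it runs on its own rather than imported from the codebase:

import re
from urllib.parse import quote


def content_disposition(filename: str) -> str:
    """Strip CR/LF/NUL, escape quotes, and fall back to RFC 5987 filename*
    for non-ASCII names, mirroring the deleted helper's approach."""
    sanitized = re.sub(r"[\r\n\x00]", "", filename).replace('"', '\\"')
    try:
        sanitized.encode("ascii")
        return f'attachment; filename="{sanitized}"'
    except UnicodeEncodeError:
        return f"attachment; filename*=UTF-8''{quote(sanitized, safe='')}"


print(content_disposition("report.pdf"))
# attachment; filename="report.pdf"
print(content_disposition("résumé.pdf"))
# attachment; filename*=UTF-8''r%C3%A9sum%C3%A9.pdf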

View File

@@ -32,7 +32,6 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
@@ -40,10 +39,6 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -57,7 +52,6 @@ from backend.util.exceptions import (
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
from backend.util.workspace_storage import shutdown_workspace_storage
from .external.fastapi_app import external_api
from .features.analytics import router as analytics_router
@@ -122,31 +116,14 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")
try:
await shutdown_workspace_storage()
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
await backend.data.db.disconnect()
@@ -338,11 +315,6 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
workspace_routes.router,
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
await asyncio.gather(execution_worker(), notification_worker())
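
Both sides of this hunk share the same shape: two async listeners driven concurrently with `asyncio.gather`. A self-contained sketch of that pattern, including the connection cleanup the removed `finally` block performed (the bus and manager objects are stand-ins):

import asyncio

async def broadcast_events(execution_bus, notification_bus, manager):
    async def execution_worker():
        async for event in execution_bus.listen("*"):
            await manager.send_execution_update(event)

    async def notification_worker():
        async for notification in notification_bus.listen("*"):
            await manager.send_notification(
                user_id=notification.user_id, payload=notification.payload
            )

    try:
        # gather() runs both listeners until one raises or the task is cancelled.
        await asyncio.gather(execution_worker(), notification_worker())
    finally:
        # Close the Pub/Sub connections on any exit to avoid leaking them.
        await execution_bus.close()
        await notification_bus.close()
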
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -13,7 +13,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -118,13 +117,11 @@ class AIImageCustomizerBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("image_url", "https://replicate.delivery/generated-image.jpg"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: MediaFileType(
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
"https://replicate.delivery/generated-image.jpg"
),
},
test_credentials=TEST_CREDENTIALS,
@@ -135,7 +132,8 @@ class AIImageCustomizerBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
@@ -143,9 +141,10 @@ class AIImageCustomizerBlock(Block):
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
user_id=user_id,
return_content=True,
)
for img in input_data.images
)
@@ -159,14 +158,7 @@ class AIImageCustomizerBlock(Block):
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
yield "image_url", result
except Exception as e:
yield "error", str(e)

View File

@@ -6,7 +6,6 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -14,8 +13,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class ImageSize(str, Enum):
@@ -168,13 +165,11 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
# Test output is a data URI since we now store images
lambda x: x.startswith("data:image/"),
"https://replicate.delivery/generated-image.webp",
),
],
test_mock={
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
"_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
},
)
@@ -323,24 +318,11 @@ class AIImageGeneratorBlock(Block):
style_text = style_map.get(style, "")
return f"{style_text} of" if style_text else ""
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
try:
url = await self.generate_image(input_data, credentials)
if url:
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
yield "image_url", url
else:
yield "error", "Image generation returned an empty result."
except Exception as e:

View File

@@ -13,7 +13,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -22,9 +21,7 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -274,10 +271,7 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_output=("video_url", "https://example.com/video.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -286,21 +280,15 @@ class AIShortformVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "data:video/mp4;base64,AAAA",
"videoUrl": "https://example.com/video.mp4",
},
# Use data URI to avoid HTTP requests during tests
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Create a new Webhook.site URL
webhook_token, webhook_url = await self.create_webhook()
@@ -352,13 +340,7 @@ class AIShortformVideoCreatorBlock(Block):
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url
class AIAdMakerVideoCreatorBlock(Block):
@@ -465,10 +447,7 @@ class AIAdMakerVideoCreatorBlock(Block):
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_output=("video_url", "https://example.com/ad.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -477,21 +456,14 @@ class AIAdMakerVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "data:video/mp4;base64,AAAA",
"videoUrl": "https://example.com/ad.mp4",
},
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -559,13 +531,7 @@ class AIAdMakerVideoCreatorBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url
class AIScreenshotToVideoAdBlock(Block):
@@ -660,10 +626,7 @@ class AIScreenshotToVideoAdBlock(Block):
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -672,21 +635,14 @@ class AIScreenshotToVideoAdBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "data:video/mp4;base64,AAAA",
"videoUrl": "https://example.com/screenshot.mp4",
},
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -754,10 +710,4 @@ class AIScreenshotToVideoAdBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url

View File

@@ -6,7 +6,6 @@ if TYPE_CHECKING:
from pydantic import SecretStr
from backend.data.execution import ExecutionContext
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -18,8 +17,6 @@ from backend.sdk import (
Requests,
SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._config import bannerbear
@@ -138,17 +135,15 @@ class BannerbearTextOverlayBlock(Block):
},
test_output=[
("success", True),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
("uid", "test-uid-123"),
("status", "completed"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"_make_api_request": lambda *args, **kwargs: {
"uid": "test-uid-123",
"status": "completed",
"image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
}
},
test_credentials=TEST_CREDENTIALS,
@@ -182,12 +177,7 @@ class BannerbearTextOverlayBlock(Block):
raise Exception(error_msg)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Build the modifications array
modifications = []
@@ -244,18 +234,6 @@ class BannerbearTextOverlayBlock(Block):
# Synchronous request - image should be ready
yield "success", True
# Store the generated image to workspace for persistence
image_url = data.get("image_url", "")
if image_url:
stored_url = await store_media_file(
file=MediaFileType(image_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "image_url", ""
yield "image_url", data.get("image_url", "")
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -9,7 +9,6 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -18,10 +17,10 @@ from backend.util.type import MediaFileType, convert
class FileStoreBlock(Block):
class Input(BlockSchemaInput):
file_in: MediaFileType = SchemaField(
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
@@ -29,18 +28,13 @@ class FileStoreBlock(Block):
class Output(BlockSchemaOutput):
file_out: MediaFileType = SchemaField(
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
description="The relative path to the stored file in the temporary directory."
)
def __init__(self):
super().__init__(
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
description=(
"Downloads and stores a file from a URL, data URI, or local path. "
"Use this to fetch images, documents, or other files for processing. "
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
"In graphs: outputs a data URI to pass to other blocks."
),
description="Stores the input file in the temporary directory.",
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
input_schema=FileStoreBlock.Input,
output_schema=FileStoreBlock.Output,
@@ -51,18 +45,15 @@ class FileStoreBlock(Block):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
execution_context=execution_context,
return_format=return_format,
user_id=user_id,
return_content=input_data.base_64,
)
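
Going by the calls in this diff, `store_media_file` takes `graph_exec_id`, `file`, `user_id`, and `return_content`, returning a relative path when `return_content=False` and an inline data URI when it is `True`. A rough usage sketch (the URL is a placeholder):

from backend.util.file import store_media_file
from backend.util.type import MediaFileType

async def fetch_and_store(graph_exec_id: str, user_id: str) -> str:
    # return_content=False -> relative path inside the execution's temp directory
    # return_content=True  -> base64 data URI containing the file's bytes
    return await store_media_file(
        graph_exec_id=graph_exec_id,
        file=MediaFileType("https://example.com/logo.png"),
        user_id=user_id,
        return_content=False,
    )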

View File

@@ -15,7 +15,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -667,7 +666,8 @@ class SendDiscordFileBlock(Block):
file: MediaFileType,
filename: str,
message_content: str,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
) -> dict:
intents = discord.Intents.default()
intents.guilds = True
@@ -731,9 +731,10 @@ class SendDiscordFileBlock(Block):
# Local file path - read from stored media file
# This would be a path from a previous block's output
stored_file = await store_media_file(
graph_exec_id=graph_exec_id,
file=file,
execution_context=execution_context,
return_format="for_external_api", # Get content to send to Discord
user_id=user_id,
return_content=True, # Get as data URI
)
# Now process as data URI
header, encoded = stored_file.split(",", 1)
@@ -780,7 +781,8 @@ class SendDiscordFileBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
@@ -791,7 +793,8 @@ class SendDiscordFileBlock(Block):
file=input_data.file,
filename=input_data.filename,
message_content=input_data.message_content,
execution_context=execution_context,
graph_exec_id=graph_exec_id,
user_id=user_id,
)
yield "status", result.get("status", "Unknown error")

View File

@@ -1,28 +0,0 @@
"""ElevenLabs integration blocks - test credentials and shared utilities."""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="elevenlabs",
api_key=SecretStr("mock-elevenlabs-api-key"),
title="Mock ElevenLabs API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
ElevenLabsCredentials = APIKeyCredentials
ElevenLabsCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
]

View File

@@ -1,77 +0,0 @@
"""Text encoding block for converting special characters to escape sequences."""
import codecs
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
class TextEncoderBlock(Block):
"""
Encodes a string by converting special characters into escape sequences.
This block is the inverse of TextDecoderBlock. It takes text containing
special characters (like newlines, tabs, etc.) and converts them into
their escape sequence representations (e.g., newline becomes \\n).
"""
class Input(BlockSchemaInput):
"""Input schema for TextEncoderBlock."""
text: str = SchemaField(
description="A string containing special characters to be encoded",
placeholder="Your text with newlines and quotes to encode",
)
class Output(BlockSchemaOutput):
"""Output schema for TextEncoderBlock."""
encoded_text: str = SchemaField(
description="The encoded text with special characters converted to escape sequences"
)
error: str = SchemaField(description="Error message if encoding fails")
def __init__(self):
super().__init__(
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
description="Encodes a string by converting special characters into escape sequences",
categories={BlockCategory.TEXT},
input_schema=TextEncoderBlock.Input,
output_schema=TextEncoderBlock.Output,
test_input={
"text": """Hello
World!
This is a "quoted" string."""
},
test_output=[
(
"encoded_text",
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
)
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Encode the input text by converting special characters to escape sequences.
Args:
input_data: The input containing the text to encode.
**kwargs: Additional keyword arguments (unused).
Yields:
The encoded text with escape sequences, or an error message if encoding fails.
"""
try:
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
"utf-8"
)
yield "encoded_text", encoded_text
except Exception as e:
yield "error", f"Encoding error: {str(e)}"

View File

@@ -17,11 +17,8 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
logger = logging.getLogger(__name__)
@@ -67,13 +64,9 @@ class AIVideoGeneratorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
# Output will be a workspace ref or data URI depending on context
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
test_mock={
# Use data URI to avoid HTTP requests during tests
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
},
)
@@ -215,22 +208,11 @@ class AIVideoGeneratorBlock(Block):
raise RuntimeError(f"API request failed: {str(e)}")
async def run(
self,
input_data: Input,
*,
credentials: FalCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: FalCredentials, **kwargs
) -> BlockOutput:
try:
video_url = await self.generate_video(input_data, credentials)
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url
except Exception as e:
error_message = str(e)
yield "error", error_message

View File

@@ -12,7 +12,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -122,12 +121,10 @@ class AIImageEditorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
# Output will be a workspace ref or data URI depending on context
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
("output_image", "https://replicate.com/output/edited-image.png"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -137,7 +134,8 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -146,25 +144,20 @@ class AIImageEditorBlock(Block):
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
user_id=user_id,
return_content=True,
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
user_id=user_id,
graph_exec_id=graph_exec_id,
)
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output_image", stored_url
yield "output_image", result
async def run_model(
self,

View File

@@ -21,7 +21,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -96,7 +95,8 @@ def _make_mime_text(
async def create_mime_message(
input_data,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
@@ -117,12 +117,12 @@ async def create_mime_message(
if input_data.attachments:
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -582,25 +582,27 @@ class GmailSendBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._send_email(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "result", result
async def _send_email(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
if not input_data.to or not input_data.subject or not input_data.body:
raise ValueError(
"At least one recipient, subject, and body are required for sending an email"
)
raw_message = await create_mime_message(input_data, execution_context)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
@@ -690,28 +692,30 @@ class GmailCreateDraftBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._create_draft(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
)
async def _create_draft(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
if not input_data.to or not input_data.subject:
raise ValueError(
"At least one recipient and subject are required for creating a draft"
)
raw_message = await create_mime_message(input_data, execution_context)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1096,7 +1100,7 @@ class GmailGetThreadBlock(GmailBase):
async def _build_reply_message(
service, input_data, execution_context: ExecutionContext
service, input_data, graph_exec_id: str, user_id: str
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
@@ -1186,12 +1190,12 @@ async def _build_reply_message(
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -1307,14 +1311,16 @@ class GmailReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
message = await self._reply(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "messageId", message["id"]
yield "threadId", message.get("threadId", input_data.threadId)
@@ -1337,11 +1343,11 @@ class GmailReplyBlock(GmailBase):
yield "email", email
async def _reply(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, execution_context
service, input_data, graph_exec_id, user_id
)
# Send the message
@@ -1435,14 +1441,16 @@ class GmailDraftReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
@@ -1450,11 +1458,11 @@ class GmailDraftReplyBlock(GmailBase):
yield "status", "draft_created"
async def _create_draft_reply(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, execution_context
service, input_data, graph_exec_id, user_id
)
# Create draft with proper thread association
@@ -1621,21 +1629,23 @@ class GmailForwardBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._forward_message(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "messageId", result["id"]
yield "threadId", result.get("threadId", "")
yield "status", "forwarded"
async def _forward_message(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
if not input_data.to:
raise ValueError("At least one recipient is required for forwarding")
@@ -1717,12 +1727,12 @@ To: {original_to}
# Add any additional attachments
for attach in input_data.additionalAttachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
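
The attachment handling above is the standard `email` package pattern; a self-contained sketch of the raw-message construction Gmail's API expects (recipient, subject, and path are placeholders, and the real helper also sets additional headers such as Cc/Bcc):

import base64
from email.encoders import encode_base64
from email.mime.base import MIMEBase
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

def build_raw_message(to: str, subject: str, body: str, attachment_path: str) -> str:
    msg = MIMEMultipart()
    msg["To"] = to
    msg["Subject"] = subject
    msg.attach(MIMEText(body, "plain"))

    part = MIMEBase("application", "octet-stream")
    with open(attachment_path, "rb") as f:
        part.set_payload(f.read())
    encode_base64(part)  # base64 transfer-encode the binary payload
    part.add_header(
        "Content-Disposition", "attachment", filename=attachment_path.rsplit("/", 1)[-1]
    )
    msg.attach(part)

    # Gmail's users().messages().send() takes the RFC 2822 message, base64url-encoded.
    return base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")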

View File

@@ -15,7 +15,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -117,9 +116,10 @@ class SendWebRequestBlock(Block):
@staticmethod
async def _prepare_files(
execution_context: ExecutionContext,
graph_exec_id: str,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -127,16 +127,11 @@ class SendWebRequestBlock(Block):
(files_name, (filename, BytesIO, mime_type))
"""
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
graph_exec_id = execution_context.graph_exec_id
if graph_exec_id is None:
raise ValueError("graph_exec_id is required for file operations")
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
file=media,
execution_context=execution_context,
return_format="for_local_processing",
graph_exec_id, media, user_id, return_content=False
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -148,7 +143,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -179,7 +174,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
execution_context, input_data.files_name, input_data.files
graph_exec_id, input_data.files_name, input_data.files, user_id
)
# Enforce body format rules
@@ -243,8 +238,9 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -275,6 +271,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, execution_context=execution_context, **kwargs
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
):
yield output_name, output_data
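
The `files_payload` built by `_prepare_files` uses the conventional multipart tuple shape, `(field_name, (filename, fileobj, mime_type))`, repeating the field name to upload several files under one key. A hedged sketch of posting such a payload with httpx, which accepts the same tuple shape (URL and paths are placeholders):

from io import BytesIO

import httpx

async def post_files(url: str, files_name: str, paths: list[str]) -> httpx.Response:
    files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
    for path in paths:
        with open(path, "rb") as f:
            data = BytesIO(f.read())
        # The same field name can be repeated to send multiple files under one key.
        files_payload.append(
            (files_name, (path.rsplit("/", 1)[-1], data, "application/octet-stream"))
        )
    async with httpx.AsyncClient() as client:
        return await client.post(url, files=files_payload)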

View File

@@ -12,7 +12,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -463,21 +462,18 @@ class AgentFileInputBlock(AgentInputBlock):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
execution_context=execution_context,
return_format=return_format,
user_id=user_id,
return_content=input_data.base_64,
)

View File

@@ -162,16 +162,8 @@ class LinearClient:
"searchTerm": team_name,
}
result = await self.query(query, variables)
nodes = result["teams"]["nodes"]
if not nodes:
raise LinearAPIException(
f"Team '{team_name}' not found. Check the team name or key and try again.",
status_code=404,
)
return nodes[0]["id"]
team_id = await self.query(query, variables)
return team_id["teams"]["nodes"][0]["id"]
except LinearAPIException as e:
raise e
@@ -248,44 +240,17 @@ class LinearClient:
except LinearAPIException as e:
raise e
async def try_search_issues(
self,
term: str,
max_results: int = 10,
team_id: str | None = None,
) -> list[Issue]:
async def try_search_issues(self, term: str) -> list[Issue]:
try:
query = """
query SearchIssues(
$term: String!,
$first: Int,
$teamId: String
) {
searchIssues(
term: $term,
first: $first,
teamId: $teamId
) {
query SearchIssues($term: String!, $includeComments: Boolean!) {
searchIssues(term: $term, includeComments: $includeComments) {
nodes {
id
identifier
title
description
priority
createdAt
state {
id
name
type
}
project {
id
name
}
assignee {
id
name
}
}
}
}
@@ -293,8 +258,7 @@ class LinearClient:
variables: dict[str, Any] = {
"term": term,
"first": max_results,
"teamId": team_id,
"includeComments": True,
}
issues = await self.query(query, variables)
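
Linear's API is a single GraphQL endpoint, so `self.query` above boils down to a POST with `query` and `variables`. A minimal standalone version (the Authorization header shown is the API-key form; OAuth tokens use `Bearer <token>` instead):

import httpx

async def linear_query(api_key: str, query: str, variables: dict) -> dict:
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "https://api.linear.app/graphql",
            json={"query": query, "variables": variables},
            headers={"Authorization": api_key},
        )
    resp.raise_for_status()
    payload = resp.json()
    if "errors" in payload:
        raise RuntimeError(payload["errors"])
    return payload["data"]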

View File

@@ -17,7 +17,7 @@ from ._config import (
LinearScope,
linear,
)
from .models import CreateIssueResponse, Issue, State
from .models import CreateIssueResponse, Issue
class LinearCreateIssueBlock(Block):
@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
max_results: int = SchemaField(
description="Maximum number of results to return",
default=10,
ge=1,
le=100,
)
team_name: str | None = SchemaField(
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
default=None,
)
class Output(BlockSchemaOutput):
issues: list[Issue] = SchemaField(description="List of issues")
error: str = SchemaField(description="Error message if the search failed")
def __init__(self):
super().__init__(
@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"term": "Test issue",
"max_results": 10,
"team_name": None,
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,
@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
[
Issue(
id="abc123",
identifier="TST-123",
identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
state=State(
id="state1", name="In Progress", type="started"
),
createdAt="2026-01-15T10:00:00.000Z",
)
],
)
@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
"search_issues": lambda *args, **kwargs: [
Issue(
id="abc123",
identifier="TST-123",
identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
state=State(id="state1", name="In Progress", type="started"),
createdAt="2026-01-15T10:00:00.000Z",
)
]
},
@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
max_results: int = 10,
team_name: str | None = None,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
# Resolve team name to ID if provided
# Raises LinearAPIException with descriptive message if team not found
team_id: str | None = None
if team_name:
team_id = await client.try_get_team_by_name(team_name=team_name)
return await client.try_search_issues(
term=term,
max_results=max_results,
team_id=team_id,
)
response: list[Issue] = await client.try_search_issues(term=term)
return response
async def run(
self,
@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
"""Execute the issue search"""
try:
issues = await self.search_issues(
credentials=credentials,
term=input_data.term,
max_results=input_data.max_results,
team_name=input_data.team_name,
credentials=credentials, term=input_data.term
)
yield "issues", issues
except LinearAPIException as e:

View File

@@ -36,21 +36,12 @@ class Project(BaseModel):
content: str | None = None
class State(BaseModel):
id: str
name: str
type: str | None = (
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
)
class Issue(BaseModel):
id: str
identifier: str
title: str
description: str | None
priority: int
state: State | None = None
project: Project | None = None
createdAt: str | None = None
comments: list[Comment] | None = None

View File

@@ -32,7 +32,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -115,7 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -271,9 +271,6 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
@@ -283,6 +280,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
@@ -638,18 +638,11 @@ async def llm_call(
context_window = llm_model.context_window
if compress_prompt_to_fit:
result = await compress_context(
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
lossy_ok=True,
)
if result.error:
logger.warning(
f"Prompt compression did not meet target: {result.error}. "
f"Proceeding with {result.token_count} tokens."
)
prompt = result.messages
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)

View File

@@ -0,0 +1,251 @@
import os
import tempfile
from typing import Literal, Optional
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class MediaDurationBlock(Block):
class Input(BlockSchemaInput):
media_in: MediaFileType = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
description="Whether the media is a video (True) or audio (False).",
default=True,
)
class Output(BlockSchemaOutput):
duration: float = SchemaField(
description="Duration of the media file (in seconds)."
)
def __init__(self):
super().__init__(
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
description="Block to get the duration of a media file.",
categories={BlockCategory.MULTIMEDIA},
input_schema=MediaDurationBlock.Input,
output_schema=MediaDurationBlock.Output,
)
async def run(
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
user_id=user_id,
return_content=False,
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
# 2) Load the clip
if input_data.is_video:
clip = VideoFileClip(media_abspath)
else:
clip = AudioFileClip(media_abspath)
yield "duration", clip.duration
class LoopVideoBlock(Block):
"""
Block for looping (repeating) a video clip until a given duration or number of loops.
"""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
duration: Optional[float] = SchemaField(
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
default=None,
ge=0.0,
)
n_loops: Optional[int] = SchemaField(
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
default=None,
ge=1,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="How to return the output video. Either a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: str = SchemaField(
description="Looped video returned either as a relative path or a data URI."
)
def __init__(self):
super().__init__(
id="8bf9eef6-5451-4213-b265-25306446e94b",
description="Block to loop a video to a given duration or number of repeats.",
categories={BlockCategory.MULTIMEDIA},
input_schema=LoopVideoBlock.Input,
output_schema=LoopVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
# 2) Load the clip
clip = VideoFileClip(input_abspath)
# 3) Apply the loop effect
looped_clip = clip
if input_data.duration:
# Loop until we reach the specified duration
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
elif input_data.n_loops:
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
else:
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
assert isinstance(looped_clip, VideoFileClip)
# 4) Save the looped output
output_filename = MediaFileType(
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
)
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return as data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
)
yield "video_out", video_out
class AddAudioToVideoBlock(Block):
"""
Block that adds (attaches) an audio track to an existing video.
Optionally scale the volume of the new track.
"""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Video input (URL, data URI, or local path)."
)
audio_in: MediaFileType = SchemaField(
description="Audio input (URL, data URI, or local path)."
)
volume: float = SchemaField(
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="Return the final output as a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Final video (with attached audio), as a path or data URI."
)
def __init__(self):
super().__init__(
id="3503748d-62b6-4425-91d6-725b064af509",
description="Block to attach an audio file to a video file using moviepy.",
categories={BlockCategory.MULTIMEDIA},
input_schema=AddAudioToVideoBlock.Input,
output_schema=AddAudioToVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
user_id=user_id,
return_content=False,
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
video_abspath = os.path.join(abs_temp_dir, local_video_path)
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
# 2) Load video + audio with moviepy
video_clip = VideoFileClip(video_abspath)
audio_clip = AudioFileClip(audio_abspath)
# Optionally scale volume
if input_data.volume != 1.0:
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
# 3) Attach the new audio track
final_clip = video_clip.with_audio(audio_clip)
# 4) Write to output file
output_filename = MediaFileType(
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
)
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return either path or data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
)
yield "video_out", video_out

View File

@@ -11,7 +11,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -113,7 +112,8 @@ class ScreenshotWebPageBlock(Block):
@staticmethod
async def take_screenshot(
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
url: str,
viewport_width: int,
viewport_height: int,
@@ -155,11 +155,12 @@ class ScreenshotWebPageBlock(Block):
return {
"image": await store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
execution_context=execution_context,
return_format="for_block_output",
user_id=user_id,
return_content=True,
)
}
@@ -168,13 +169,15 @@ class ScreenshotWebPageBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
execution_context=execution_context,
graph_exec_id=graph_exec_id,
user_id=user_id,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -7,7 +7,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
@@ -99,7 +98,7 @@ class ReadSpreadsheetBlock(Block):
)
async def run(
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
@@ -107,16 +106,14 @@ class ReadSpreadsheetBlock(Block):
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
# Get full file path
assert execution_context.graph_exec_id # Validated by store_media_file
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@property
def provider_name(self) -> str:
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -182,7 +182,10 @@ class StagehandObserveBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
@@ -227,7 +230,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -279,7 +282,10 @@ class StagehandActBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
@@ -324,7 +330,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -364,7 +370,10 @@ class StagehandExtractBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(

View File

@@ -10,7 +10,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -18,9 +17,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -105,7 +102,7 @@ class CreateTalkingAvatarVideoBlock(Block):
test_output=[
(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
),
],
test_mock={
@@ -113,10 +110,9 @@ class CreateTalkingAvatarVideoBlock(Block):
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
"status": "created",
},
# Use data URI to avoid HTTP requests during tests
"get_clip_status": lambda *args, **kwargs: {
"status": "done",
"result_url": "data:video/mp4;base64,AAAA",
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
},
},
test_credentials=TEST_CREDENTIALS,
@@ -142,12 +138,7 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Create the clip
payload = {
@@ -174,14 +165,7 @@ class CreateTalkingAvatarVideoBlock(Block):
for _ in range(input_data.max_polling_attempts):
status_response = await self.get_clip_status(credentials.api_key, clip_id)
if status_response["status"] == "done":
# Store the generated video to the user's workspace for persistence
video_url = status_response["result_url"]
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", status_response["result_url"]
return
elif status_response["status"] == "error":
raise RuntimeError(
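
The polling loop above reduces to a small, reusable coroutine; a sketch with illustrative attempt and delay values (`get_status` stands in for the block's `get_clip_status`):

import asyncio

async def poll_until_done(get_status, clip_id: str, attempts: int = 30, delay: float = 2.0) -> str:
    for _ in range(attempts):
        status = await get_status(clip_id)
        if status["status"] == "done":
            return status["result_url"]
        if status["status"] == "error":
            raise RuntimeError(f"Clip {clip_id} failed: {status.get('error')}")
        await asyncio.sleep(delay)
    raise TimeoutError(f"Clip {clip_id} was not ready after {attempts} attempts")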

View File

@@ -12,7 +12,6 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
@@ -234,12 +233,9 @@ class TestStoreMediaFileSecurity:
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
user_id="test_user",
)
@patch("backend.util.file.Path")
@@ -274,12 +270,9 @@ class TestStoreMediaFileSecurity:
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
user_id="test_user",
)

View File

@@ -11,22 +11,10 @@ from backend.blocks.http import (
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.execution import ExecutionContext
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
def make_test_context(
graph_exec_id: str = "test-exec-id",
user_id: str = "test-user-id",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@@ -117,7 +105,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -172,7 +161,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -218,7 +208,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -267,7 +258,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -326,7 +318,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -389,7 +382,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -477,7 +471,8 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
execution_context=make_test_context(),
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))

View File

@@ -1,77 +0,0 @@
import pytest
from backend.blocks.encoder_block import TextEncoderBlock
@pytest.mark.asyncio
async def test_text_encoder_basic():
"""Test basic encoding of newlines and special characters."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert result[0][1] == "Hello\\nWorld"
@pytest.mark.asyncio
async def test_text_encoder_multiple_escapes():
"""Test encoding of multiple escape sequences."""
block = TextEncoderBlock()
result = []
async for output in block.run(
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert "\\n" in result[0][1]
assert "\\t" in result[0][1]
assert "\\r" in result[0][1]
@pytest.mark.asyncio
async def test_text_encoder_unicode():
"""Test that unicode characters are handled correctly."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
# Unicode characters should be escaped as \uXXXX sequences
assert "\\n" in result[0][1]
@pytest.mark.asyncio
async def test_text_encoder_empty_string():
"""Test encoding of an empty string."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert result[0][1] == ""
@pytest.mark.asyncio
async def test_text_encoder_error_handling():
"""Test that encoding errors are handled gracefully."""
from unittest.mock import patch
block = TextEncoderBlock()
result = []
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
async for output in block.run(TextEncoderBlock.Input(text="test")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "error"
assert "Mocked encoding error" in result[0][1]

View File

@@ -11,7 +11,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util import json, text
from backend.util.file import get_exec_file_path, store_media_file
@@ -445,21 +444,18 @@ class FileReadBlock(Block):
)
async def run(
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
) -> BlockOutput:
# Store the media file properly (handles URLs, data URIs, etc.)
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
# Get full file path (graph_exec_id validated by store_media_file above)
if not execution_context.graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -1,37 +0,0 @@
"""Video editing blocks for AutoGPT Platform.
This module provides blocks for:
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
- Clipping/trimming video segments
- Concatenating multiple videos
- Adding text overlays
- Adding AI-generated narration
- Getting media duration
- Looping videos
- Adding audio to videos
Dependencies:
- yt-dlp: For video downloading
- moviepy: For video editing operations
- elevenlabs: For AI narration (optional)
"""
from backend.blocks.video.add_audio import AddAudioToVideoBlock
from backend.blocks.video.clip import VideoClipBlock
from backend.blocks.video.concat import VideoConcatBlock
from backend.blocks.video.download import VideoDownloadBlock
from backend.blocks.video.duration import MediaDurationBlock
from backend.blocks.video.loop import LoopVideoBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
__all__ = [
"AddAudioToVideoBlock",
"LoopVideoBlock",
"MediaDurationBlock",
"VideoClipBlock",
"VideoConcatBlock",
"VideoDownloadBlock",
"VideoNarrationBlock",
"VideoTextOverlayBlock",
]

View File

@@ -1,131 +0,0 @@
"""Shared utilities for video blocks."""
from __future__ import annotations
import logging
import os
import re
import subprocess
from pathlib import Path
logger = logging.getLogger(__name__)
# Known operation tags added by video blocks
_VIDEO_OPS = (
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
)
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
_BLOCK_PREFIX_RE = re.compile(
r"^[a-zA-Z0-9_-]*"
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
r"[a-zA-Z0-9_-]*"
r"_" + _VIDEO_OPS + r"_"
)
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
_UUID_PREFIX_RE = re.compile(
r"^[a-zA-Z0-9_-]*"
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
r"[a-zA-Z0-9_-]*_"
)
def extract_source_name(input_path: str, max_length: int = 50) -> str:
"""Extract the original source filename by stripping block-generated prefixes.
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
when chaining video blocks, recovering the original human-readable name.
Safe for plain filenames (no UUID -> no stripping).
Falls back to "video" if everything is stripped.
"""
stem = Path(input_path).stem
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
while _BLOCK_PREFIX_RE.match(stem):
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
if _UUID_PREFIX_RE.match(stem):
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
if not stem:
return "video"
return stem[:max_length]
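A quick illustration of the stripping behaviour described in the docstring above; the node_exec_id value and filenames are made up:

nid = "node-123e4567-e89b-12d3-a456-426614174000"
extract_source_name(f"{nid}_looped_{nid}_clip_holiday.mp4")  # -> "holiday"
extract_source_name(f"{nid}_clip_holiday.mp4")               # -> "holiday"
extract_source_name("holiday.mp4")                           # -> "holiday" (no UUID, left untouched)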
def get_video_codecs(output_path: str) -> tuple[str, str]:
"""Get appropriate video and audio codecs based on output file extension.
Args:
output_path: Path to the output file (used to determine extension)
Returns:
Tuple of (video_codec, audio_codec)
Codec mappings:
- .mp4: H.264 + AAC (universal compatibility)
- .webm: VP8 + Vorbis (web streaming)
- .mkv: H.264 + AAC (container supports many codecs)
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
- .m4v: H.264 + AAC (Apple iTunes/devices)
- .avi: MPEG-4 + MP3 (legacy Windows)
"""
ext = os.path.splitext(output_path)[1].lower()
codec_map: dict[str, tuple[str, str]] = {
".mp4": ("libx264", "aac"),
".webm": ("libvpx", "libvorbis"),
".mkv": ("libx264", "aac"),
".mov": ("libx264", "aac"),
".m4v": ("libx264", "aac"),
".avi": ("mpeg4", "libmp3lame"),
}
return codec_map.get(ext, ("libx264", "aac"))
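For example (output paths are illustrative):

get_video_codecs("final.webm")  # -> ("libvpx", "libvorbis")
get_video_codecs("final.avi")   # -> ("mpeg4", "libmp3lame")
get_video_codecs("final.flv")   # -> ("libx264", "aac"), the fallback for unmapped extensions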
def strip_chapters_inplace(video_path: str) -> None:
"""Strip chapter metadata from a media file in-place using ffmpeg.
MoviePy 2.x crashes with IndexError when parsing files with embedded
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
This strips chapters without re-encoding.
Args:
video_path: Absolute path to the media file to strip chapters from.
"""
base, ext = os.path.splitext(video_path)
tmp_path = base + ".tmp" + ext
try:
result = subprocess.run(
[
"ffmpeg",
"-y",
"-i",
video_path,
"-map_chapters",
"-1",
"-codec",
"copy",
tmp_path,
],
capture_output=True,
text=True,
timeout=300,
)
if result.returncode != 0:
logger.warning(
"ffmpeg chapter strip failed (rc=%d): %s",
result.returncode,
result.stderr,
)
return
os.replace(tmp_path, video_path)
except FileNotFoundError:
logger.warning("ffmpeg not found; skipping chapter strip")
finally:
if os.path.exists(tmp_path):
os.unlink(tmp_path)
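A minimal usage sketch of the helper above (the path is made up); the comments restate what the subprocess call already does:

# Roughly: ffmpeg -y -i video.mp4 -map_chapters -1 -codec copy video.tmp.mp4,
# followed by os.replace() of the temp file over the original.
strip_chapters_inplace("/tmp/exec-dir/video.mp4")  # no-op (with a warning) if ffmpeg is missing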

View File

@@ -1,113 +0,0 @@
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class AddAudioToVideoBlock(Block):
"""Add (attach) an audio track to an existing video."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Video input (URL, data URI, or local path)."
)
audio_in: MediaFileType = SchemaField(
description="Audio input (URL, data URI, or local path)."
)
volume: float = SchemaField(
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Final video (with attached audio), as a path or data URI."
)
def __init__(self):
super().__init__(
id="3503748d-62b6-4425-91d6-725b064af509",
description="Block to attach an audio file to a video file using moviepy.",
categories={BlockCategory.MULTIMEDIA},
input_schema=AddAudioToVideoBlock.Input,
output_schema=AddAudioToVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
file=input_data.audio_in,
execution_context=execution_context,
return_format="for_local_processing",
)
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
# 2) Load video + audio with moviepy
strip_chapters_inplace(video_abspath)
strip_chapters_inplace(audio_abspath)
video_clip = None
audio_clip = None
final_clip = None
try:
video_clip = VideoFileClip(video_abspath)
audio_clip = AudioFileClip(audio_abspath)
# Optionally scale volume
if input_data.volume != 1.0:
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
# 3) Attach the new audio track
final_clip = video_clip.with_audio(audio_clip)
# 4) Write to output file
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
final_clip.write_videofile(
output_abspath, codec="libx264", audio_codec="aac"
)
finally:
if final_clip:
final_clip.close()
if audio_clip:
audio_clip.close()
if video_clip:
video_clip.close()
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -1,167 +0,0 @@
"""VideoClipBlock - Extract a segment from a video file."""
from typing import Literal
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoClipBlock(Block):
"""Extract a time segment from a video."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Input video (URL, data URI, or local path)"
)
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
description="Output format", default="mp4", advanced=True
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Clipped video file (path or data URI)"
)
duration: float = SchemaField(description="Clip duration in seconds")
def __init__(self):
super().__init__(
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
description="Extract a time segment from a video",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"video_in": "/tmp/test.mp4",
"start_time": 0.0,
"end_time": 10.0,
},
test_output=[("video_out", str), ("duration", float)],
test_mock={
"_clip_video": lambda *args: 10.0,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _clip_video(
self,
video_abspath: str,
output_abspath: str,
start_time: float,
end_time: float,
) -> float:
"""Extract a clip from a video. Extracted for testability."""
clip = None
subclip = None
try:
strip_chapters_inplace(video_abspath)
clip = VideoFileClip(video_abspath)
subclip = clip.subclipped(start_time, end_time)
video_codec, audio_codec = get_video_codecs(output_abspath)
subclip.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
return subclip.duration
finally:
if subclip:
subclip.close()
if clip:
clip.close()
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Validate time range
if input_data.end_time <= input_data.start_time:
raise BlockExecutionError(
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
block_name=self.name,
block_id=str(self.id),
)
try:
assert execution_context.graph_exec_id is not None
# Store the input video locally
local_video_path = await self._store_input_video(
execution_context, input_data.video_in
)
video_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_video_path
)
# Build output path
source = extract_source_name(local_video_path)
output_filename = MediaFileType(
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
)
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
duration = self._clip_video(
video_abspath,
output_abspath,
input_data.start_time,
input_data.end_time,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
yield "video_out", video_out
yield "duration", duration
except BlockExecutionError:
raise
except Exception as e:
raise BlockExecutionError(
message=f"Failed to clip video: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -1,227 +0,0 @@
"""VideoConcatBlock - Concatenate multiple video clips into one."""
from typing import Literal
from moviepy import concatenate_videoclips
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoConcatBlock(Block):
"""Merge multiple video clips into one continuous video."""
class Input(BlockSchemaInput):
videos: list[MediaFileType] = SchemaField(
description="List of video files to concatenate (in order)"
)
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
description="Transition between clips", default="none"
)
transition_duration: int = SchemaField(
description="Transition duration in seconds",
default=1,
ge=0,
advanced=True,
)
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
description="Output format", default="mp4", advanced=True
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Concatenated video file (path or data URI)"
)
total_duration: float = SchemaField(description="Total duration in seconds")
def __init__(self):
super().__init__(
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
description="Merge multiple video clips into one continuous video",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
},
test_output=[
("video_out", str),
("total_duration", float),
],
test_mock={
"_concat_videos": lambda *args: 20.0,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _concat_videos(
self,
video_abspaths: list[str],
output_abspath: str,
transition: str,
transition_duration: int,
) -> float:
"""Concatenate videos. Extracted for testability.
Returns:
Total duration of the concatenated video.
"""
clips = []
faded_clips = []
final = None
try:
# Load clips
for v in video_abspaths:
strip_chapters_inplace(v)
clips.append(VideoFileClip(v))
# Validate transition_duration against shortest clip
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
min_duration = min(c.duration for c in clips)
if transition_duration >= min_duration:
raise BlockExecutionError(
message=(
f"transition_duration ({transition_duration}s) must be "
f"shorter than the shortest clip ({min_duration:.2f}s)"
),
block_name=self.name,
block_id=str(self.id),
)
if transition == "crossfade":
for i, clip in enumerate(clips):
effects = []
if i > 0:
effects.append(CrossFadeIn(transition_duration))
if i < len(clips) - 1:
effects.append(CrossFadeOut(transition_duration))
if effects:
clip = clip.with_effects(effects)
faded_clips.append(clip)
final = concatenate_videoclips(
faded_clips,
method="compose",
padding=-transition_duration,
)
elif transition == "fade_black":
for clip in clips:
faded = clip.with_effects(
[FadeIn(transition_duration), FadeOut(transition_duration)]
)
faded_clips.append(faded)
final = concatenate_videoclips(faded_clips)
else:
final = concatenate_videoclips(clips)
video_codec, audio_codec = get_video_codecs(output_abspath)
final.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
return final.duration
finally:
if final:
final.close()
for clip in faded_clips:
clip.close()
for clip in clips:
clip.close()
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Validate minimum clips
if len(input_data.videos) < 2:
raise BlockExecutionError(
message="At least 2 videos are required for concatenation",
block_name=self.name,
block_id=str(self.id),
)
try:
assert execution_context.graph_exec_id is not None
# Store all input videos locally
video_abspaths = []
for video in input_data.videos:
local_path = await self._store_input_video(execution_context, video)
video_abspaths.append(
get_exec_file_path(execution_context.graph_exec_id, local_path)
)
# Build output path
source = (
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
)
output_filename = MediaFileType(
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
)
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
total_duration = self._concat_videos(
video_abspaths,
output_abspath,
input_data.transition,
input_data.transition_duration,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
yield "video_out", video_out
yield "total_duration", total_duration
except BlockExecutionError:
raise
except Exception as e:
raise BlockExecutionError(
message=f"Failed to concatenate videos: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -1,172 +0,0 @@
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
import os
import typing
from typing import Literal
import yt_dlp
if typing.TYPE_CHECKING:
from yt_dlp import _Params
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoDownloadBlock(Block):
"""Download video from URL using yt-dlp."""
class Input(BlockSchemaInput):
url: str = SchemaField(
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
placeholder="https://www.youtube.com/watch?v=...",
)
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
description="Video quality preference", default="720p"
)
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
description="Output video format", default="mp4", advanced=True
)
class Output(BlockSchemaOutput):
video_file: MediaFileType = SchemaField(
description="Downloaded video (path or data URI)"
)
duration: float = SchemaField(description="Video duration in seconds")
title: str = SchemaField(description="Video title from source")
source_url: str = SchemaField(description="Original source URL")
def __init__(self):
super().__init__(
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
test_input={
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"quality": "480p",
},
test_output=[
("video_file", str),
("duration", float),
("title", str),
("source_url", str),
],
test_mock={
"_download_video": lambda *args: (
"video.mp4",
212.0,
"Test Video",
),
"_store_output_video": lambda *args, **kwargs: "video.mp4",
},
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _get_format_string(self, quality: str) -> str:
formats = {
"best": "bestvideo+bestaudio/best",
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
"audio_only": "bestaudio/best",
}
return formats.get(quality, formats["720p"])
def _download_video(
self,
url: str,
quality: str,
output_format: str,
output_dir: str,
node_exec_id: str,
) -> tuple[str, float, str]:
"""Download video. Extracted for testability."""
output_template = os.path.join(
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
)
ydl_opts: "_Params" = {
"format": f"{self._get_format_string(quality)}/best",
"outtmpl": output_template,
"merge_output_format": output_format,
"quiet": True,
"no_warnings": True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
info = ydl.extract_info(url, download=True)
video_path = ydl.prepare_filename(info)
# Handle format conversion in filename
if not video_path.endswith(f".{output_format}"):
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
# Return just the filename, not the full path
filename = os.path.basename(video_path)
return (
filename,
info.get("duration") or 0.0,
info.get("title") or "Unknown",
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
try:
assert execution_context.graph_exec_id is not None
# Get the exec file directory
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
os.makedirs(output_dir, exist_ok=True)
filename, duration, title = self._download_video(
input_data.url,
input_data.quality,
input_data.output_format,
output_dir,
node_exec_id,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, MediaFileType(filename)
)
yield "video_file", video_out
yield "duration", duration
yield "title", title
yield "source_url", input_data.url
except Exception as e:
raise BlockExecutionError(
message=f"Failed to download video: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -1,77 +0,0 @@
"""MediaDurationBlock - Get the duration of a media file."""
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import strip_chapters_inplace
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class MediaDurationBlock(Block):
"""Get the duration of a media file (video or audio)."""
class Input(BlockSchemaInput):
media_in: MediaFileType = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
description="Whether the media is a video (True) or audio (False).",
default=True,
)
class Output(BlockSchemaOutput):
duration: float = SchemaField(
description="Duration of the media file (in seconds)."
)
def __init__(self):
super().__init__(
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
description="Block to get the duration of a media file.",
categories={BlockCategory.MULTIMEDIA},
input_schema=MediaDurationBlock.Input,
output_schema=MediaDurationBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
file=input_data.media_in,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
# 2) Strip chapters to avoid MoviePy crash, then load the clip
strip_chapters_inplace(media_abspath)
clip = None
try:
if input_data.is_video:
clip = VideoFileClip(media_abspath)
else:
clip = AudioFileClip(media_abspath)
duration = clip.duration
finally:
if clip:
clip.close()
yield "duration", duration

View File

@@ -1,115 +0,0 @@
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
from typing import Optional
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class LoopVideoBlock(Block):
"""Loop (repeat) a video clip until a given duration or number of loops."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
duration: Optional[float] = SchemaField(
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
default=None,
ge=0.0,
le=3600.0, # Max 1 hour to prevent disk exhaustion
)
n_loops: Optional[int] = SchemaField(
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
default=None,
ge=1,
le=10, # Max 10 loops to prevent disk exhaustion
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Looped video returned either as a relative path or a data URI."
)
def __init__(self):
super().__init__(
id="8bf9eef6-5451-4213-b265-25306446e94b",
description="Block to loop a video to a given duration or number of repeats.",
categories={BlockCategory.MULTIMEDIA},
input_schema=LoopVideoBlock.Input,
output_schema=LoopVideoBlock.Output,
)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
# 2) Load the clip
strip_chapters_inplace(input_abspath)
clip = None
looped_clip = None
try:
clip = VideoFileClip(input_abspath)
# 3) Apply the loop effect
if input_data.duration:
# Loop until we reach the specified duration
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
elif input_data.n_loops:
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
else:
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
assert isinstance(looped_clip, VideoFileClip)
# 4) Save the looped output
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(
output_abspath, codec="libx264", audio_codec="aac"
)
finally:
if looped_clip:
looped_clip.close()
if clip:
clip.close()
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -1,267 +0,0 @@
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
import os
from typing import Literal
from elevenlabs import ElevenLabs
from moviepy import CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.elevenlabs._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ElevenLabsCredentials,
ElevenLabsCredentialsInput,
)
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoNarrationBlock(Block):
"""Generate AI narration and add to video."""
class Input(BlockSchemaInput):
credentials: ElevenLabsCredentialsInput = CredentialsField(
description="ElevenLabs API key for voice synthesis"
)
video_in: MediaFileType = SchemaField(
description="Input video (URL, data URI, or local path)"
)
script: str = SchemaField(description="Narration script text")
voice_id: str = SchemaField(
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
)
model_id: Literal[
"eleven_multilingual_v2",
"eleven_flash_v2_5",
"eleven_turbo_v2_5",
"eleven_turbo_v2",
] = SchemaField(
description="ElevenLabs TTS model",
default="eleven_multilingual_v2",
)
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
default="ducking",
)
narration_volume: float = SchemaField(
description="Narration volume (0.0 to 2.0)",
default=1.0,
ge=0.0,
le=2.0,
advanced=True,
)
original_volume: float = SchemaField(
description="Original audio volume when mixing (0.0 to 1.0)",
default=0.3,
ge=0.0,
le=1.0,
advanced=True,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Video with narration (path or data URI)"
)
audio_file: MediaFileType = SchemaField(
description="Generated audio file (path or data URI)"
)
def __init__(self):
super().__init__(
id="3d036b53-859c-4b17-9826-ca340f736e0e",
description="Generate AI narration and add to video",
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"video_in": "/tmp/test.mp4",
"script": "Hello world",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("video_out", str), ("audio_file", str)],
test_mock={
"_generate_narration_audio": lambda *args: b"mock audio content",
"_add_narration_to_video": lambda *args: None,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _generate_narration_audio(
self, api_key: str, script: str, voice_id: str, model_id: str
) -> bytes:
"""Generate narration audio via ElevenLabs API."""
client = ElevenLabs(api_key=api_key)
audio_generator = client.text_to_speech.convert(
voice_id=voice_id,
text=script,
model_id=model_id,
)
# The SDK returns a generator, collect all chunks
return b"".join(audio_generator)
def _add_narration_to_video(
self,
video_abspath: str,
audio_abspath: str,
output_abspath: str,
mix_mode: str,
narration_volume: float,
original_volume: float,
) -> None:
"""Add narration audio to video. Extracted for testability."""
video = None
final = None
narration_original = None
narration_scaled = None
original = None
try:
strip_chapters_inplace(video_abspath)
video = VideoFileClip(video_abspath)
narration_original = AudioFileClip(audio_abspath)
narration_scaled = narration_original.with_volume_scaled(narration_volume)
narration = narration_scaled
if mix_mode == "replace":
final_audio = narration
elif mix_mode == "mix":
if video.audio:
original = video.audio.with_volume_scaled(original_volume)
final_audio = CompositeAudioClip([original, narration])
else:
final_audio = narration
else: # ducking - apply stronger attenuation
if video.audio:
# Ducking uses a much lower volume for original audio
ducking_volume = original_volume * 0.3
original = video.audio.with_volume_scaled(ducking_volume)
final_audio = CompositeAudioClip([original, narration])
else:
final_audio = narration
final = video.with_audio(final_audio)
video_codec, audio_codec = get_video_codecs(output_abspath)
final.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
finally:
if original:
original.close()
if narration_scaled:
narration_scaled.close()
if narration_original:
narration_original.close()
if final:
final.close()
if video:
video.close()
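Worked example of the mixing rules above, using the block defaults (narration_volume=1.0, original_volume=0.3): "mix" scales the original track to 0.3 of its source level with the narration at full volume, while "ducking" scales the original to 0.3 * 0.3 = 0.09, so the narration sits clearly on top.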
async def run(
self,
input_data: Input,
*,
credentials: ElevenLabsCredentials,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
try:
assert execution_context.graph_exec_id is not None
# Store the input video locally
local_video_path = await self._store_input_video(
execution_context, input_data.video_in
)
video_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_video_path
)
# Generate narration audio via ElevenLabs
audio_content = self._generate_narration_audio(
credentials.api_key.get_secret_value(),
input_data.script,
input_data.voice_id,
input_data.model_id,
)
# Save audio to exec file path
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
audio_abspath = get_exec_file_path(
execution_context.graph_exec_id, audio_filename
)
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
with open(audio_abspath, "wb") as f:
f.write(audio_content)
# Add narration to video
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
self._add_narration_to_video(
video_abspath,
audio_abspath,
output_abspath,
input_data.mix_mode,
input_data.narration_volume,
input_data.original_volume,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
audio_out = await self._store_output_video(
execution_context, audio_filename
)
yield "video_out", video_out
yield "audio_file", audio_out
except Exception as e:
raise BlockExecutionError(
message=f"Failed to add narration: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -1,231 +0,0 @@
"""VideoTextOverlayBlock - Add text overlay to video."""
from typing import Literal
from moviepy import CompositeVideoClip, TextClip
from moviepy.video.io.VideoFileClip import VideoFileClip
from backend.blocks.video._utils import (
extract_source_name,
get_video_codecs,
strip_chapters_inplace,
)
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
class VideoTextOverlayBlock(Block):
"""Add text overlay/caption to video."""
class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="Input video (URL, data URI, or local path)"
)
text: str = SchemaField(description="Text to overlay on video")
position: Literal[
"top",
"center",
"bottom",
"top-left",
"top-right",
"bottom-left",
"bottom-right",
] = SchemaField(description="Position of text on screen", default="bottom")
start_time: float | None = SchemaField(
description="When to show text (seconds). None = entire video",
default=None,
advanced=True,
)
end_time: float | None = SchemaField(
description="When to hide text (seconds). None = until end",
default=None,
advanced=True,
)
font_size: int = SchemaField(
description="Font size", default=48, ge=12, le=200, advanced=True
)
font_color: str = SchemaField(
description="Font color (hex or name)", default="white", advanced=True
)
bg_color: str | None = SchemaField(
description="Background color behind text (None for transparent)",
default=None,
advanced=True,
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
description="Video with text overlay (path or data URI)"
)
def __init__(self):
super().__init__(
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
description="Add text overlay/caption to video",
categories={BlockCategory.MULTIMEDIA},
input_schema=self.Input,
output_schema=self.Output,
disabled=True, # Disable until we can lockdown imagemagick security policy
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
test_output=[("video_out", str)],
test_mock={
"_add_text_overlay": lambda *args: None,
"_store_input_video": lambda *args, **kwargs: "test.mp4",
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
},
)
async def _store_input_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store input video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_local_processing",
)
async def _store_output_video(
self, execution_context: ExecutionContext, file: MediaFileType
) -> MediaFileType:
"""Store output video. Extracted for testability."""
return await store_media_file(
file=file,
execution_context=execution_context,
return_format="for_block_output",
)
def _add_text_overlay(
self,
video_abspath: str,
output_abspath: str,
text: str,
position: str,
start_time: float | None,
end_time: float | None,
font_size: int,
font_color: str,
bg_color: str | None,
) -> None:
"""Add text overlay to video. Extracted for testability."""
video = None
final = None
txt_clip = None
try:
strip_chapters_inplace(video_abspath)
video = VideoFileClip(video_abspath)
txt_clip = TextClip(
text=text,
font_size=font_size,
color=font_color,
bg_color=bg_color,
)
# Position mapping
pos_map = {
"top": ("center", "top"),
"center": ("center", "center"),
"bottom": ("center", "bottom"),
"top-left": ("left", "top"),
"top-right": ("right", "top"),
"bottom-left": ("left", "bottom"),
"bottom-right": ("right", "bottom"),
}
txt_clip = txt_clip.with_position(pos_map[position])
# Set timing
start = start_time or 0
end = end_time or video.duration
duration = max(0, end - start)
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
final = CompositeVideoClip([video, txt_clip])
video_codec, audio_codec = get_video_codecs(output_abspath)
final.write_videofile(
output_abspath, codec=video_codec, audio_codec=audio_codec
)
finally:
if txt_clip:
txt_clip.close()
if final:
final.close()
if video:
video.close()
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Validate time range if both are provided
if (
input_data.start_time is not None
and input_data.end_time is not None
and input_data.end_time <= input_data.start_time
):
raise BlockExecutionError(
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
block_name=self.name,
block_id=str(self.id),
)
try:
assert execution_context.graph_exec_id is not None
# Store the input video locally
local_video_path = await self._store_input_video(
execution_context, input_data.video_in
)
video_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_video_path
)
# Build output path
source = extract_source_name(local_video_path)
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
output_abspath = get_exec_file_path(
execution_context.graph_exec_id, output_filename
)
self._add_text_overlay(
video_abspath,
output_abspath,
input_data.text,
input_data.position,
input_data.start_time,
input_data.end_time,
input_data.font_size,
input_data.font_color,
input_data.bg_color,
)
# Return as workspace path or data URI based on context
video_out = await self._store_output_video(
execution_context, output_filename
)
yield "video_out", video_out
except BlockExecutionError:
raise
except Exception as e:
raise BlockExecutionError(
message=f"Failed to add text overlay: {e}",
block_name=self.name,
block_id=str(self.id),
) from e

View File

@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
credentials: WebshareProxyCredentials,
**kwargs,
) -> BlockOutput:
try:
video_id = self.extract_video_id(input_data.youtube_url)
transcript = self.get_transcript(video_id, credentials)
transcript_text = self.format_transcript(transcript=transcript)
video_id = self.extract_video_id(input_data.youtube_url)
yield "video_id", video_id
# Only yield after all operations succeed
yield "video_id", video_id
yield "transcript", transcript_text
except Exception as e:
yield "error", str(e)
transcript = self.get_transcript(video_id, credentials)
transcript_text = self.format_transcript(transcript=transcript)
yield "transcript", transcript_text

View File

@@ -873,13 +873,14 @@ def is_block_auth_configured(
async def initialize_blocks() -> None:
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs
from backend.util.retry import func_retry
sync_all_provider_costs()
@func_retry
async def sync_block_to_db(block: Block) -> None:
for cls in get_blocks().values():
block = cls()
existing_block = await AgentBlock.prisma().find_first(
where={"OR": [{"id": block.id}, {"name": block.name}]}
)
@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
return
continue
input_schema = json.dumps(block.input_schema.jsonschema())
output_schema = json.dumps(block.output_schema.jsonschema())
@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
},
)
failed_blocks: list[str] = []
for cls in get_blocks().values():
block = cls()
try:
await sync_block_to_db(block)
except Exception as e:
logger.warning(
f"Failed to sync block {block.name} to database: {e}. "
"Block is still available in memory.",
exc_info=True,
)
failed_blocks.append(block.name)
if failed_blocks:
logger.error(
f"Failed to sync {len(failed_blocks)} block(s) to database: "
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
)
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> AnyBlockSchema | None:

View File

@@ -36,14 +36,12 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data.block import Block, BlockCost, BlockCostType
from backend.integrations.credentials_store import (
aiml_api_credentials,
anthropic_credentials,
apollo_credentials,
did_credentials,
elevenlabs_credentials,
enrichlayer_credentials,
groq_credentials,
ideogram_credentials,
@@ -80,10 +78,10 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_6_OPUS: 14,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -642,16 +640,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
),
],
VideoNarrationBlock: [
BlockCost(
cost_amount=5, # ElevenLabs TTS cost
cost_filter={
"credentials": {
"id": elevenlabs_credentials.id,
"provider": elevenlabs_credentials.provider,
"type": elevenlabs_credentials.type,
}
},
)
],
}

View File

@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
user_credit.time_now = lambda: month1
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
# in a different month than month1 (January). This fixes a timing bug
# where if the test runs in early February, 35 days ago would be January,
# matching the mocked month1 and preventing the refill from triggering.
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID},
data={"updatedAt": dec_previous_year},
)
# First call in month 1 should trigger refill
balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE # Should get 1000 credits

View File

@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
def __init__(self):
self._pubsub: AsyncPubSub | None = None
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()
async def close(self) -> None:
"""Close the PubSub connection if it exists."""
if self._pubsub is not None:
try:
await self._pubsub.close()
except Exception:
logger.warning("Failed to close PubSub connection", exc_info=True)
finally:
self._pubsub = None
async def publish_event(self, event: M, channel_key: str):
"""
Publish an event to Redis. Gracefully handles connection failures
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
await self.connection, channel_key
)
assert isinstance(pubsub, AsyncPubSub)
self._pubsub = pubsub
if "*" in channel_key:
await pubsub.psubscribe(full_channel_name)

View File

@@ -83,29 +83,12 @@ class ExecutionContext(BaseModel):
model_config = {"extra": "ignore"}
# Execution identity
user_id: Optional[str] = None
graph_id: Optional[str] = None
graph_exec_id: Optional[str] = None
graph_version: Optional[int] = None
node_id: Optional[str] = None
node_exec_id: Optional[str] = None
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
# User settings
user_timezone: str = "UTC"
# Execution hierarchy
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# Workspace
workspace_id: Optional[str] = None
session_id: Optional[str] = None
# -------------------------- Models -------------------------- #

View File

@@ -1028,39 +1028,6 @@ async def get_graph(
return GraphModel.from_db(graph, for_export)
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
"""Batch-fetch multiple store-listed graphs by their IDs.
Only returns graphs that have approved store listings (publicly available).
Does not require permission checks since store-listed graphs are public.
Args:
*graph_ids: Variable number of graph IDs to fetch
Returns:
Dict mapping graph_id to GraphModel for graphs with approved store listings
"""
if not graph_ids:
return {}
store_listings = await StoreListingVersion.prisma().find_many(
where={
"agentGraphId": {"in": list(graph_ids)},
"submissionStatus": SubmissionStatus.APPROVED,
"isDeleted": False,
},
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
distinct=["agentGraphId"],
order={"agentGraphVersion": "desc"},
)
return {
listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
for listing in store_listings
if listing.AgentGraph
}
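A hypothetical call to the batch helper above (the IDs are made up); graphs without an approved store listing are simply absent from the returned mapping:

graphs = await get_store_listed_graphs("graph-id-1", "graph-id-2")
for graph_id in ("graph-id-1", "graph-id-2"):
    if graph_id not in graphs:
        ...  # no approved, public listing for this graph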
async def get_graph_as_admin(
graph_id: str,
version: int | None = None,

View File

@@ -19,6 +19,7 @@ from typing import (
cast,
get_args,
)
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType, OnboardingStep
@@ -41,7 +42,6 @@ from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.request import parse_url
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
@@ -397,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
def matches_url(self, url: str) -> bool:
"""Check if this credential should be applied to the given URL."""
request_host, request_port = _extract_host_from_url(url)
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
parsed_url = urlparse(url)
# Extract hostname without port
request_host = parsed_url.hostname
if not request_host:
return False
# If a port is specified in credential host, the request host port must match
if cred_scope_port is not None and request_port != cred_scope_port:
return False
# Non-standard ports are only allowed if explicitly specified in credential host
elif cred_scope_port is None and request_port not in (80, 443, None):
return False
# Simple host matching
if cred_scope_host == request_host:
# Simple host matching - exact match or wildcard subdomain match
if self.host == request_host:
return True
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
if cred_scope_host.startswith("*."):
domain = cred_scope_host[2:] # Remove "*."
if self.host.startswith("*."):
domain = self.host[2:] # Remove "*."
return request_host.endswith(f".{domain}") or request_host == domain
return False
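A standalone sketch of the wildcard host rule shared by both sides of this diff (hosts and URLs are made up; port handling, which the two versions treat differently, is deliberately omitted):

from urllib.parse import urlparse

def host_matches(cred_host: str, url: str) -> bool:
    request_host = urlparse(url).hostname
    if not request_host:
        return False
    if cred_host == request_host:
        return True
    if cred_host.startswith("*."):
        domain = cred_host[2:]  # "*.example.com" -> "example.com"
        return request_host == domain or request_host.endswith(f".{domain}")
    return False

host_matches("*.example.com", "https://api.example.com/v1")      # True
host_matches("*.example.com", "https://sub.api.example.com/v1")  # True
host_matches("*.example.com", "https://example.com/v1")          # True, the bare domain also matches
host_matches("*.example.com", "https://example.org/v1")          # False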
@@ -557,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
"""Extract host and port from URL for grouping host-scoped credentials."""
def _extract_host_from_url(url: str) -> str:
"""Extract host from URL for grouping host-scoped credentials."""
try:
parsed = parse_url(url)
return parsed.hostname or url, parsed.port
parsed = urlparse(url)
return parsed.hostname or url
except Exception:
return "", None
return ""
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
@@ -612,7 +606,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
providers = frozenset(
[cast(CP, "http")]
+ [
cast(CP, parse_url(str(value)).netloc)
cast(CP, _extract_host_from_url(str(value)))
for value in field.discriminator_values
]
)
@@ -672,16 +666,10 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if not (self.discriminator and self.discriminator_mapping):
return self
try:
provider = self.discriminator_mapping[discriminator_value]
except KeyError:
raise ValueError(
f"Model '{discriminator_value}' is not supported. "
"It may have been deprecated. Please update your agent configuration."
)
return CredentialsFieldInfo(
credentials_provider=frozenset([provider]),
credentials_provider=frozenset(
[self.discriminator_mapping[discriminator_value]]
),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
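The hunk above replaces a guarded lookup, which raised a descriptive `ValueError` for unknown discriminator values, with a direct dictionary access. A small, hedged sketch of the guarded variant, with an invented mapping:

```python
# Sketch of the guarded lookup removed above; the mapping contents are invented.
DISCRIMINATOR_MAPPING = {
    "gpt-4o": "openai",
    "claude-3-5-sonnet": "anthropic",
}


def resolve_provider(discriminator_value: str) -> str:
    try:
        return DISCRIMINATOR_MAPPING[discriminator_value]
    except KeyError:
        raise ValueError(
            f"Model '{discriminator_value}' is not supported. "
            "It may have been deprecated. Please update your agent configuration."
        )


if __name__ == "__main__":
    print(resolve_provider("gpt-4o"))  # openai
    # resolve_provider("unknown-model") would raise the descriptive ValueError
```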

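For the `matches_url` hunk earlier in this file: the removed variant is port-aware (non-standard ports must be named explicitly in the credential host), while the added variant matches on hostname only. A self-contained sketch approximating the port-aware behavior is below; `urlparse` stands in for the project's `parse_url`, and the helper names are illustrative.

```python
# Approximation of the port-aware matching removed above. urlparse stands in
# for the project's parse_url helper; names here are illustrative.
from urllib.parse import urlparse


def extract_host_and_port(value: str) -> tuple[str, int | None]:
    # Accept bare hosts ("localhost:8080") as well as full URLs.
    parsed = urlparse(value if "//" in value else f"//{value}")
    return parsed.hostname or "", parsed.port


def matches_url(credential_host: str, url: str) -> bool:
    request_host, request_port = extract_host_and_port(url)
    scope_host, scope_port = extract_host_and_port(credential_host)
    if not request_host:
        return False
    # If the credential names a port, the request must use that exact port.
    if scope_port is not None and request_port != scope_port:
        return False
    # Otherwise only standard ports (or no port) are accepted.
    if scope_port is None and request_port not in (80, 443, None):
        return False
    if scope_host == request_host:
        return True
    # Wildcard subdomain matching: "*.example.com" covers api.example.com too.
    if scope_host.startswith("*."):
        domain = scope_host[2:]
        return request_host == domain or request_host.endswith(f".{domain}")
    return False


if __name__ == "__main__":
    print(matches_url("localhost", "http://localhost:8080/api"))       # False
    print(matches_url("localhost:8080", "http://localhost:8080/api"))  # True
    print(matches_url("*.example.com", "https://api.example.com/x"))   # True
```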
View File

@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
headers={"Authorization": SecretStr("Bearer token")},
)
# Non-standard ports require explicit port in credential host
assert not creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("https://localhost:443/secure/endpoint")
assert creds.matches_url("http://localhost/simple")
def test_matches_url_with_explicit_port(self):
"""Test URL matching with explicit port in credential host."""
creds = HostScopedCredentials(
provider="custom",
host="localhost:8080",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
assert not creds.matches_url("http://localhost:3000/api/v1")
assert not creds.matches_url("http://localhost/simple")
def test_empty_headers_dict(self):
"""Test HostScopedCredentials with empty headers."""
creds = HostScopedCredentials(
@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
("*.example.com", "https://sub.api.example.com/test", True),
("*.example.com", "https://example.com/test", True),
("*.example.com", "https://example.org/test", False),
# Non-standard ports require explicit port in credential host
("localhost", "http://localhost:3000/test", False),
("localhost:3000", "http://localhost:3000/test", True),
("localhost", "http://localhost:3000/test", True),
("localhost", "http://127.0.0.1:3000/test", False),
# IPv6 addresses (frontend stores with brackets via URL.hostname)
("[::1]", "http://[::1]/test", True),
("[::1]", "http://[::1]:80/test", True),
("[::1]", "https://[::1]:443/test", True),
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
("[::1]:8080", "http://[::1]:8080/test", True),
("[::1]:8080", "http://[::1]:9090/test", False),
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
],
)
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
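A minimal, self-contained illustration of the parametrized style used in this test file; the matcher below is a simplified stand-in for `HostScopedCredentials`, reflecting the hostname-only behavior the remaining cases expect.

```python
# Simplified stand-in matcher plus a parametrized test in the same style as
# above; this is not the project's HostScopedCredentials implementation.
from urllib.parse import urlparse

import pytest


def host_matches(credential_host: str, url: str) -> bool:
    hostname = urlparse(url).hostname or ""
    if credential_host.startswith("*."):
        domain = credential_host[2:]
        return hostname == domain or hostname.endswith(f".{domain}")
    return hostname == credential_host


@pytest.mark.parametrize(
    "credential_host,test_url,expected",
    [
        ("*.example.com", "https://sub.api.example.com/test", True),
        ("*.example.com", "https://example.org/test", False),
        ("localhost", "http://localhost:3000/test", True),
    ],
)
def test_host_matches(credential_host: str, test_url: str, expected: bool):
    assert host_matches(credential_host, test_url) is expected
```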

View File

@@ -1,276 +0,0 @@
"""
Database CRUD operations for User Workspace.
This module provides functions for managing user workspaces and workspace files.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.models import UserWorkspace, UserWorkspaceFile
from prisma.types import UserWorkspaceFileWhereInput
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
"""
Get user's workspace, creating one if it doesn't exist.
Uses upsert to handle race conditions when multiple concurrent requests
attempt to create a workspace for the same user.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance
"""
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No updates needed if exists
},
)
return workspace
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
"""
Get user's workspace if it exists.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
async def create_workspace_file(
workspace_id: str,
file_id: str,
name: str,
path: str,
storage_path: str,
mime_type: str,
size_bytes: int,
checksum: Optional[str] = None,
metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
"""
Create a new workspace file record.
Args:
workspace_id: The workspace ID
file_id: The file ID (same as used in storage path for consistency)
name: User-visible filename
path: Virtual path (e.g., "/documents/report.pdf")
storage_path: Actual storage path (GCS or local)
mime_type: MIME type of the file
size_bytes: File size in bytes
checksum: Optional SHA256 checksum
metadata: Optional additional metadata
Returns:
Created UserWorkspaceFile instance
"""
# Normalize path to start with /
if not path.startswith("/"):
path = f"/{path}"
file = await UserWorkspaceFile.prisma().create(
data={
"id": file_id,
"workspaceId": workspace_id,
"name": name,
"path": path,
"storagePath": storage_path,
"mimeType": mime_type,
"sizeBytes": size_bytes,
"checksum": checksum,
"metadata": SafeJson(metadata or {}),
}
)
logger.info(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
)
return file
async def get_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by ID.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
UserWorkspaceFile instance or None
"""
where_clause: dict = {"id": file_id, "isDeleted": False}
if workspace_id:
where_clause["workspaceId"] = workspace_id
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
async def get_workspace_file_by_path(
workspace_id: str,
path: str,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by its virtual path.
Args:
workspace_id: The workspace ID
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
# Normalize path
if not path.startswith("/"):
path = f"/{path}"
return await UserWorkspaceFile.prisma().find_first(
where={
"workspaceId": workspace_id,
"path": path,
"isDeleted": False,
}
)
async def list_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
limit: Optional[int] = None,
offset: int = 0,
) -> list[UserWorkspaceFile]:
"""
List files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/documents/")
include_deleted: Whether to include soft-deleted files
limit: Maximum number of files to return
offset: Number of files to skip
Returns:
List of UserWorkspaceFile instances
"""
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().find_many(
where=where_clause,
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def count_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
) -> int:
"""
Count files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
include_deleted: Whether to include soft-deleted files
Returns:
Number of files
"""
where_clause: dict = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().count(where=where_clause)
async def soft_delete_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Soft-delete a workspace file.
The path is modified to include a deletion timestamp to free up the original
path for new files while preserving the record for potential recovery.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
Updated UserWorkspaceFile instance or None if not found
"""
# First verify the file exists and belongs to workspace
file = await get_workspace_file(file_id, workspace_id)
if file is None:
return None
deleted_at = datetime.now(timezone.utc)
# Modify path to free up the unique constraint for new files at original path
# Format: {original_path}__deleted__{timestamp}
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
updated = await UserWorkspaceFile.prisma().update(
where={"id": file_id},
data={
"isDeleted": True,
"deletedAt": deleted_at,
"path": deleted_path,
},
)
logger.info(f"Soft-deleted workspace file {file_id}")
return updated
async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.sizeBytes for file in files)
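Two conventions in the removed module above are worth isolating: virtual paths are normalized to a leading slash, and soft deletion renames the path so the original path is freed for reuse. A hedged, standalone sketch (helper names are illustrative):

```python
# Standalone sketch of the path conventions used in the module above; the
# helper names here are illustrative, not taken from the project.
from datetime import datetime, timezone


def normalize_path(path: str) -> str:
    return path if path.startswith("/") else f"/{path}"


def soft_deleted_path(path: str, deleted_at: datetime | None = None) -> str:
    deleted_at = deleted_at or datetime.now(timezone.utc)
    # Format mirrors the module above: {original_path}__deleted__{timestamp}
    return f"{normalize_path(path)}__deleted__{int(deleted_at.timestamp())}"


if __name__ == "__main__":
    print(normalize_path("documents/report.pdf"))  # /documents/report.pdf
    print(soft_deleted_path("/documents/report.pdf", datetime(2026, 1, 1, tzinfo=timezone.utc)))
    # /documents/report.pdf__deleted__1767225600
```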

View File

@@ -17,7 +17,6 @@ from backend.data.analytics import (
get_accuracy_trends_and_alerts,
get_marketplace_graphs_for_monitoring,
)
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
# Onboarding
increment_onboarding_runs = _(increment_onboarding_runs)
# OAuth
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
# Store
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# Onboarding
increment_onboarding_runs = d.increment_onboarding_runs
# OAuth
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
# Store
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details
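The hunks above drop `cleanup_expired_oauth_tokens` from the service facade that re-exports data-layer functions. A generic, hedged illustration of that registration style follows; the `expose` wrapper is a no-op stand-in for the project's `_` helper and does not reproduce AppService's real remote-call behavior.

```python
# Generic illustration of registering plain async data-layer functions on a
# service facade. `expose` is a no-op stand-in for the project's `_` wrapper;
# the real AppService machinery is not reproduced here.
import asyncio
from typing import Awaitable, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Awaitable[object]])


def expose(fn: F) -> F:
    # A real wrapper would register fn for remote invocation; this one does not.
    return fn


async def increment_onboarding_runs(user_id: str) -> None:
    print(f"incremented onboarding runs for {user_id}")


async def cleanup_expired_oauth_tokens() -> int:
    return 0  # a real implementation would delete expired rows and count them


class DatabaseManager:
    # Onboarding
    increment_onboarding_runs = staticmethod(expose(increment_onboarding_runs))
    # OAuth
    cleanup_expired_oauth_tokens = staticmethod(expose(cleanup_expired_oauth_tokens))


if __name__ == "__main__":
    asyncio.run(DatabaseManager.increment_onboarding_runs("user-1"))
```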

View File

@@ -236,14 +236,7 @@ async def execute_node(
input_size = len(input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
# Create node-specific execution context to avoid race conditions
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
execution_context = execution_context.model_copy(
update={"node_id": node_id, "node_exec_id": node_exec_id}
)
# Inject extra execution arguments for the blocks via kwargs
# Keep individual kwargs for backwards compatibility with existing blocks
extra_exec_kwargs: dict = {
"graph_id": graph_id,
"graph_version": graph_version,

View File

@@ -24,9 +24,11 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
@@ -36,11 +38,7 @@ from backend.monitoring import (
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import (
get_database_manager_async_client,
get_database_manager_client,
get_scheduler_client,
)
from backend.util.clients import get_database_manager_client, get_scheduler_client
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import (
GraphNotFoundError,
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
start_time = asyncio.get_event_loop().time()
db = get_database_manager_async_client()
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await db.increment_onboarding_runs(args.user_id)
await increment_onboarding_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -249,13 +246,8 @@ def cleanup_expired_files():
def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database."""
# Wait for completion
async def _cleanup():
db = get_database_manager_async_client()
return await db.cleanup_expired_oauth_tokens()
run_async(_cleanup())
run_async(cleanup_expired_oauth_tokens())
def execution_accuracy_alerts():
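Both variants in the hunk above drive an async cleanup from a synchronous scheduler job via `run_async`; the change is only whether the coroutine comes from the database manager client or is imported directly. A hedged sketch of the direct-call shape, with `asyncio.run` standing in for `run_async` and a placeholder cleanup coroutine:

```python
# Sketch of a synchronous scheduler entrypoint awaiting an async cleanup.
# asyncio.run stands in for the project's run_async helper, and the cleanup
# coroutine below is a placeholder, not the real implementation.
import asyncio


async def cleanup_expired_oauth_tokens() -> int:
    await asyncio.sleep(0)  # a real implementation would delete expired rows
    return 0


def cleanup_oauth_tokens() -> None:
    """Synchronous job body that waits for the async cleanup to finish."""
    removed = asyncio.run(cleanup_expired_oauth_tokens())
    print(f"Removed {removed} expired OAuth tokens")


if __name__ == "__main__":
    cleanup_oauth_tokens()
```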

Some files were not shown because too many files have changed in this diff.