Compare commits

..

1 Commits

Author SHA1 Message Date
Zamil Majdy
2fdc311293 Test check 2026-01-23 10:37:37 -05:00
280 changed files with 6078 additions and 21153 deletions

View File

@@ -29,7 +29,8 @@
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install"
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false

View File

@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
**Backend Entry Points:**
- `backend/backend/api/rest_api.py` - FastAPI application setup
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
### API Development
1. Update routes in `/backend/backend/api/features/`
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)

2
.gitignore vendored
View File

@@ -178,6 +178,4 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next

View File

@@ -16,6 +16,7 @@ See `docs/content/platform/getting-started.md` for setup instructions.
- Format Python code with `poetry run format`.
- Format frontend code using `pnpm format`.
## Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -32,17 +33,14 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any`, if not types available use `unknown`
## Testing
@@ -51,8 +49,22 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
Types: - feat - fix - refactor - ci - dx (developer experience)
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
Types:
- feat
- fix
- refactor
- ci
- dx (developer experience)
Scopes:
- platform
- platform/library
- platform/marketplace
- backend
- backend/executor
- frontend
- frontend/library
- frontend/marketplace
- blocks
## Pull requests

View File

@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -6,30 +6,152 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
## Component Documentation
## Essential Commands
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
### Backend Development
## Key Concepts
```bash
# Install dependencies
cd backend && poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend server
poetry run serve
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
### Frontend Development
```bash
# Install dependencies
cd frontend && pnpm i
# Generate API client from OpenAPI spec
pnpm generate:api
# Start development server
pnpm dev
# Run E2E tests
pnpm test
# Run Storybook for component development
pnpm storybook
# Build production
pnpm build
# Format and lint
pnpm format
# Type checking
pnpm types
```
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
### Frontend Architecture
- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development
### Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
#### Docker Environment Loading Order
@@ -45,12 +167,83 @@ AutoGPT Platform is a monorepo containing:
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Common Development Tasks
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
### Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
- Add `usePageName.ts` hook for logic
- Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Create the PR aginst the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
- Use conventional commit messages (see below)/
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
- Run the github pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests

View File

@@ -178,10 +178,5 @@ AYRSHARE_JWT_KEY=
SMARTLEAD_API_KEY=
ZEROBOUNCE_API_KEY=
# PostHog Analytics
# Get API key from https://posthog.com - Project Settings > Project API Key
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Other Services
AUTOMOD_API_KEY=

View File

@@ -1,170 +0,0 @@
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
#### Using Global Auth Fixtures
Two global auth fixtures are provided by `backend/api/conftest.py`:
Two global auth fixtures are provided by `backend/server/conftest.py`:
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")

View File

@@ -86,8 +86,6 @@ async def execute_graph_block(
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled:
raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.")
output = defaultdict(list)
async for name, data in obj.execute(data):

View File

@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
)
# Taken from backend/api/features/store/db.py
# Taken from backend/server/v2/store/db.py
def sanitize_query(query: str | None) -> str | None:
if query is None:
return query

View File

@@ -1,368 +0,0 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -1,344 +0,0 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -33,57 +33,9 @@ class ChatConfig(BaseSettings):
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
default=3, description="Maximum number of agent schedules"
)
# Langfuse Prompt Management Configuration
@@ -124,14 +76,6 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -247,45 +247,3 @@ async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(
session_id: str,
tool_call_id: str,
new_content: str,
) -> bool:
"""Update the content of a tool message in chat history.
Used by background tasks to update pending operation messages with final results.
Args:
session_id: The chat session ID.
tool_call_id: The tool call ID to find the message.
new_content: The new content to set.
Returns:
True if a message was updated, False otherwise.
"""
try:
result = await PrismaChatMessage.prisma().update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={
"content": new_content,
},
)
if result == 0:
logger.warning(
f"No message found to update for session {session_id}, "
f"tool_call_id {tool_call_id}"
)
return False
return True
except Exception as e:
logger.error(
f"Failed to update tool message for session {session_id}, "
f"tool_call_id {tool_call_id}: {e}"
)
return False

View File

@@ -295,21 +295,6 @@ async def cache_chat_session(session: ChatSession) -> None:
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
Used by background tasks to ensure fresh data is loaded on next access.
This is best-effort - Redis failures are logged but don't fail the operation.
"""
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Best-effort: log but don't fail - cache will expire naturally
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)

View File

@@ -31,7 +31,6 @@ class ResponseType(str, Enum):
# Other
ERROR = "error"
USAGE = "usage"
HEARTBEAT = "heartbeat"
class StreamBaseResponse(BaseModel):
@@ -52,10 +51,6 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
class StreamFinish(StreamBaseResponse):
@@ -147,20 +142,3 @@ class StreamError(StreamBaseResponse):
details: dict[str, Any] | None = Field(
default=None, description="Additional error details"
)
class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations.
Uses SSE comment format (: comment) which is ignored by clients but keeps
the connection alive through proxies and load balancers.
"""
type: ResponseType = ResponseType.HEARTBEAT
toolCallId: str | None = Field(
default=None, description="Tool call ID if heartbeat is for a specific tool"
)
def to_sse(self) -> str:
"""Convert to SSE comment format to keep connection alive."""
return ": heartbeat\n\n"

View File

@@ -1,23 +1,19 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -188,14 +166,13 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
SessionDetailResponse: Details for the requested session, or None if not found.
"""
session = await get_chat_session(session_id, user_id)
@@ -203,28 +180,11 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
id=session.session_id,
@@ -232,7 +192,6 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -252,112 +211,49 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
subscriber_queue = None
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass # Client disconnected - background task continues
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -470,251 +366,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live settings for chat streams, which determines
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========

File diff suppressed because it is too large Load Diff

View File

@@ -1,704 +0,0 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
@dataclass
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
All delivery is via Redis Streams - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for task {task_id}"
)
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for task {task_id}, "
"queue completely blocked"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
_listener_tasks.pop(queue_id, None)
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> bool:
"""Mark a task as completed and publish finish event.
This is idempotent - calling multiple times with the same task_id is safe.
Uses atomic compare-and-swap via Lua script to prevent race conditions.
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
Returns:
True if task was newly marked completed, False if already completed/failed
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
# Atomic compare-and-swap: only update if status is "running"
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish finish event for task {task_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
return True
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if not task_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
# Note: Redis client uses decode_responses=True, so keys/values are strings
return ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
async def get_task_with_expiry_info(
task_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
Returns (task, error_code) where error_code is:
- None if task found
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
- "TASK_NOT_FOUND" if neither exists
Args:
task_id: Task ID to look up
Returns:
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Check if stream still has data (metadata expired but stream hasn't)
stream_len = await redis.xlen(stream_key)
if stream_len > 0:
return None, "TASK_EXPIRED"
return None, "TASK_NOT_FOUND"
# Note: Redis client uses decode_responses=True, so keys/values are strings
return (
ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
),
None,
)
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
redis = await get_redis_async()
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
if task_user_id and user_id != task_user_id:
continue
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
if cursor == 0:
break
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
# Map response types to their corresponding classes
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
ResponseType.ERROR.value: StreamError,
ResponseType.USAGE.value: StreamUsage,
ResponseType.HEARTBEAT.value: StreamHeartbeat,
}
chunk_type = chunk_data.get("type")
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
if chunk_class is None:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
try:
return chunk_class(**chunk_data)
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the task state is in Redis.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to track
"""
_local_tasks[task_id] = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
Cancels the XREAD-based listener task associated with this subscriber queue
to prevent resource leaks.
Args:
task_id: Task ID
subscriber_queue: The subscriber's queue used to look up the listener task
"""
queue_id = id(subscriber_queue)
listener_entry = _listener_tasks.pop(queue_id, None)
if listener_entry is None:
logger.debug(
f"No listener task found for task {task_id} queue {queue_id} "
"(may have already completed)"
)
return
stored_task_id, listener_task = listener_entry
if stored_task_id != task_id:
logger.warning(
f"Task ID mismatch in unsubscribe: expected {task_id}, "
f"found {stored_task_id}"
)
if listener_task.done():
logger.debug(f"Listener task for task {task_id} already completed")
return
# Cancel the listener task
listener_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(listener_task, timeout=5.0)
except asyncio.CancelledError:
# Expected - the task was successfully cancelled
pass
except asyncio.TimeoutError:
logger.warning(
f"Timeout waiting for listener task cancellation for task {task_id}"
)
except Exception as e:
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
logger.debug(f"Successfully unsubscribed from task {task_id}")

View File

@@ -1,79 +0,0 @@
# CoPilot Tools - Future Ideas
## Multimodal Image Support for CoPilot
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
**Backend Solution:**
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
```python
# Before sending to LLM, scan for workspace image references
# and inject them as image content parts
# Example message transformation:
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
# TO: {"role": "assistant", "content": [
# {"type": "text", "text": "Generated image: workspace://abc123"},
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# ]}
```
**Where to implement:**
- In the chat stream handler before calling the LLM
- Or in a message preprocessing step
- Need to fetch image from workspace, convert to base64, add as image content
**Considerations:**
- Only do this for image MIME types (image/png, image/jpeg, etc.)
- May want a size limit (don't pass 10MB images)
- Track which images were "shown" to the AI for frontend indicator
- Cost implications - vision API calls are more expensive
**Frontend Solution:**
Show visual indicator on workspace files in chat:
- If AI saw the image: normal display
- If AI didn't see it: overlay icon saying "AI can't see this image"
Requires response metadata indicating which `workspace://` refs were passed to the model.
---
## Output Post-Processing Layer for run_block
**Problem:** Many blocks produce large outputs that:
- Consume massive context (100KB base64 image = ~133KB tokens)
- Can't fit in conversation
- Break things and cause high LLM costs
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
**Benefits:**
1. **Centralized** - one place to handle all output processing
2. **Future-proof** - new blocks automatically get output processing
3. **Keeps blocks pure** - they don't need to know about context constraints
4. **Handles all large outputs** - not just images
**Processing Rules:**
- Detect base64 data URIs → save to workspace, return `workspace://` reference
- Truncate very long strings (>N chars) with truncation note
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
- Handle nested large outputs in dicts recursively
- Cap total output size
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
**Example:**
```python
def _process_outputs_for_context(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
max_string_length: int = 10000,
max_array_preview: int = 5,
) -> dict[str, list[Any]]:
"""Process block outputs to prevent context bloat."""
processed = {}
for name, values in outputs.items():
processed[name] = [_process_value(v, workspace_manager) for v in values]
return processed
```

View File

@@ -1,16 +1,13 @@
import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
@@ -19,23 +16,14 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WriteWorkspaceFileTool,
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
@@ -45,11 +33,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),
"delete_workspace_file": DeleteWorkspaceFileTool(),
}
# Export individual tool instances for backwards compatibility
@@ -62,11 +45,6 @@ tools: list[ChatCompletionToolParam] = [
]
def get_tool(tool_name: str) -> BaseTool | None:
"""Get a tool instance by name."""
return TOOL_REGISTRY.get(tool_name)
async def execute_tool(
tool_name: str,
parameters: dict[str, Any],
@@ -75,20 +53,7 @@ async def execute_tool(
tool_call_id: str,
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
tool = get_tool(tool_name)
tool = TOOL_REGISTRY.get(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
# Track tool call in PostHog
logger.info(
f"Tracking tool call: tool={tool_name}, user={user_id}, "
f"session={session.session_id}, call_id={tool_call_id}"
)
track_tool_called(
user_id=user_id,
session_id=session.session_id,
tool_name=tool_name,
tool_call_id=tool_call_id,
)
return await tool.execute(user_id, session, tool_call_id, **parameters)

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
@@ -59,6 +61,7 @@ and automations for the user's specific needs."""
"""Requires authentication to store user-specific data."""
return True
@observe(as_type="tool", name="add_understanding")
async def _execute(
self,
user_id: str | None,

View File

@@ -1,59 +1,29 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
apply_agent_patch,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
save_agent_to_library,
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
from .fixer import apply_all_fixes
from .utils import get_blocks_info
from .validator import validate_agent
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
# Core functions
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"apply_agent_patch",
"save_agent_to_library",
"search_marketplace_agents_for_generation",
"get_agent_as_json",
# Fixer
"apply_all_fixes",
# Validator
"validate_agent",
# Utils
"get_blocks_info",
]

View File

@@ -0,0 +1,25 @@
"""OpenRouter client configuration for agent generation."""
import os
from openai import AsyncOpenAI
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
# OpenRouter client (OpenAI-compatible API)
_client: AsyncOpenAI | None = None
def get_client() -> AsyncOpenAI:
"""Get or create the OpenRouter client."""
global _client
if _client is None:
if not OPENROUTER_API_KEY:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
_client = AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=OPENROUTER_API_KEY,
)
return _client

View File

@@ -1,95 +0,0 @@
"""Error handling utilities for agent generator."""
import re
def _sanitize_error_details(details: str) -> str:
"""Sanitize error details to remove sensitive information.
Strips common patterns that could expose internal system info:
- File paths (Unix and Windows)
- Database connection strings
- URLs with credentials
- Stack trace internals
Args:
details: Raw error details string
Returns:
Sanitized error details safe for user display
"""
sanitized = re.sub(
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
)
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
sanitized = re.sub(
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
)
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
sanitized = re.sub(r", line \d+", "", sanitized)
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
return sanitized.strip()
def get_user_message_for_error(
error_type: str,
operation: str = "process the request",
llm_parse_message: str | None = None,
validation_message: str | None = None,
error_details: str | None = None,
) -> str:
"""Get a user-friendly error message based on error type.
This function maps internal error types to user-friendly messages,
providing a consistent experience across different agent operations.
Args:
error_type: The error type from the external service
(e.g., "llm_parse_error", "timeout", "rate_limit")
operation: Description of what operation failed, used in the default
message (e.g., "analyze the goal", "generate the agent")
llm_parse_message: Custom message for llm_parse_error type
validation_message: Custom message for validation_error type
error_details: Optional additional details about the error
Returns:
User-friendly error message suitable for display to the user
"""
base_message = ""
if error_type == "llm_parse_error":
base_message = (
llm_parse_message
or "The AI had trouble processing this request. Please try again."
)
elif error_type == "validation_error":
base_message = (
validation_message
or "The generated agent failed validation. "
"This usually happens when the agent structure doesn't match "
"what the platform expects. Please try simplifying your goal "
"or breaking it into smaller parts."
)
elif error_type == "patch_error":
base_message = (
"Failed to apply the changes. The modification couldn't be "
"validated. Please try a different approach or simplify the change."
)
elif error_type in ("timeout", "llm_timeout"):
base_message = (
"The request took too long to process. This can happen with "
"complex agents. Please try again or simplify your goal."
)
elif error_type in ("rate_limit", "llm_rate_limit"):
base_message = "The service is currently busy. Please try again in a moment."
else:
base_message = f"Failed to {operation}. Please try again."
if error_details:
details = _sanitize_error_details(error_details)
if len(details) > 200:
details = details[:200] + "..."
base_message += f"\n\nTechnical details: {details}"
return base_message

View File

@@ -0,0 +1,606 @@
"""Agent fixer - Fixes common LLM generation errors."""
import logging
import re
import uuid
from typing import Any
from .utils import (
ADDTODICTIONARY_BLOCK_ID,
ADDTOLIST_BLOCK_ID,
CODE_EXECUTION_BLOCK_ID,
CONDITION_BLOCK_ID,
CREATEDICT_BLOCK_ID,
CREATELIST_BLOCK_ID,
DATA_SAMPLING_BLOCK_ID,
DOUBLE_CURLY_BRACES_BLOCK_IDS,
GET_CURRENT_DATE_BLOCK_ID,
STORE_VALUE_BLOCK_ID,
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
get_blocks_info,
is_valid_uuid,
)
logger = logging.getLogger(__name__)
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix invalid UUIDs in agent and link IDs."""
# Fix agent ID
if not is_valid_uuid(agent.get("id", "")):
agent["id"] = str(uuid.uuid4())
logger.debug(f"Fixed agent ID: {agent['id']}")
# Fix node IDs
id_mapping = {} # Old ID -> New ID
for node in agent.get("nodes", []):
if not is_valid_uuid(node.get("id", "")):
old_id = node.get("id", "")
new_id = str(uuid.uuid4())
id_mapping[old_id] = new_id
node["id"] = new_id
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
# Fix link IDs and update references
for link in agent.get("links", []):
if not is_valid_uuid(link.get("id", "")):
link["id"] = str(uuid.uuid4())
logger.debug(f"Fixed link ID: {link['id']}")
# Update source/sink IDs if they were remapped
if link.get("source_id") in id_mapping:
link["source_id"] = id_mapping[link["source_id"]]
if link.get("sink_id") in id_mapping:
link["sink_id"] = id_mapping[link["sink_id"]]
return agent
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix single curly braces to double in template blocks."""
for node in agent.get("nodes", []):
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
continue
input_data = node.get("input_default", {})
for key in ("prompt", "format"):
if key in input_data and isinstance(input_data[key], str):
original = input_data[key]
# Fix simple variable references: {var} -> {{var}}
fixed = re.sub(
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
r"{{\1}}",
original,
)
if fixed != original:
input_data[key] = fixed
logger.debug(f"Fixed curly braces in {key}")
return agent
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Find all ConditionBlock nodes
condition_node_ids = {
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
}
if not condition_node_ids:
return agent
new_nodes = []
new_links = []
processed_conditions = set()
for link in links:
sink_id = link.get("sink_id")
sink_name = link.get("sink_name")
# Check if this link goes to a ConditionBlock's value2
if sink_id in condition_node_ids and sink_name == "value2":
source_node = next(
(n for n in nodes if n["id"] == link.get("source_id")), None
)
# Skip if source is already a StoreValueBlock
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
continue
# Skip if we already processed this condition
if sink_id in processed_conditions:
continue
processed_conditions.add(sink_id)
# Create StoreValueBlock
store_node_id = str(uuid.uuid4())
store_node = {
"id": store_node_id,
"block_id": STORE_VALUE_BLOCK_ID,
"input_default": {"data": None},
"metadata": {"position": {"x": 0, "y": -100}},
}
new_nodes.append(store_node)
# Create link: original source -> StoreValueBlock
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": store_node_id,
"sink_name": "input",
"is_static": False,
}
)
# Update original link: StoreValueBlock -> ConditionBlock
link["source_id"] = store_node_id
link["source_name"] = "output"
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
if new_nodes:
agent["nodes"] = nodes + new_nodes
return agent
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
When an AddToList block is found:
1. Checks if there's a CreateListBlock before it
2. Removes CreateListBlock if linked directly to AddToList
3. Adds an empty AddToList block before the original
4. Ensures the original has a self-referencing link
"""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
new_nodes = []
original_addtolist_ids = set()
nodes_to_remove = set()
links_to_remove = []
# First pass: identify CreateListBlock nodes to remove
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATELIST_BLOCK_ID
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
# Second pass: process AddToList blocks
filtered_nodes = []
for node in nodes:
if node.get("id") in nodes_to_remove:
continue
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
original_addtolist_ids.add(node.get("id"))
node_id = node.get("id")
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
# Check if already has prerequisite
has_prereq = any(
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_name") == "updated_list"
for link in links
)
if not has_prereq:
# Remove links to "list" input (except self-reference)
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "list"
and link.get("source_id") != node_id
and link not in links_to_remove
):
links_to_remove.append(link)
# Create prerequisite AddToList block
prereq_id = str(uuid.uuid4())
prereq_node = {
"id": prereq_id,
"block_id": ADDTOLIST_BLOCK_ID,
"input_default": {"list": [], "entry": None, "entries": []},
"metadata": {
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
},
}
new_nodes.append(prereq_node)
# Link prerequisite to original
links.append(
{
"id": str(uuid.uuid4()),
"source_id": prereq_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added prerequisite AddToList block for {node_id}")
filtered_nodes.append(node)
# Remove marked links
filtered_links = [link for link in links if link not in links_to_remove]
# Add self-referencing links for original AddToList blocks
for node in filtered_nodes + new_nodes:
if (
node.get("block_id") == ADDTOLIST_BLOCK_ID
and node.get("id") in original_addtolist_ids
):
node_id = node.get("id")
has_self_ref = any(
link["source_id"] == node_id
and link["sink_id"] == node_id
and link["source_name"] == "updated_list"
and link["sink_name"] == "list"
for link in filtered_links
)
if not has_self_ref:
filtered_links.append(
{
"id": str(uuid.uuid4()),
"source_id": node_id,
"source_name": "updated_list",
"sink_id": node_id,
"sink_name": "list",
"is_static": False,
}
)
logger.debug(f"Added self-reference for AddToList {node_id}")
agent["nodes"] = filtered_nodes + new_nodes
agent["links"] = filtered_links
return agent
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
nodes_to_remove = set()
links_to_remove = []
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
if (
source_node
and sink_node
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
):
nodes_to_remove.add(source_node.get("id"))
links_to_remove.append(link)
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
for link in links:
source_node = next(
(n for n in nodes if n.get("id") == link.get("source_id")), None
)
if (
source_node
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
and link.get("source_name") == "response"
):
link["source_name"] = "stdout_logs"
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
return agent
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
links_to_remove = []
for node in nodes:
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Remove links to sample_size
for link in links:
if (
link.get("sink_id") == node_id
and link.get("sink_name") == "sample_size"
):
links_to_remove.append(link)
# Set default
input_default["sample_size"] = 1
node["input_default"] = input_default
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
if links_to_remove:
agent["links"] = [link for link in links if link not in links_to_remove]
return agent
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
node_lookup = {n.get("id"): n for n in nodes}
for link in links:
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_node = node_lookup.get(source_id)
sink_node = node_lookup.get(sink_id)
if not source_node or not sink_node:
continue
source_pos = source_node.get("metadata", {}).get("position", {})
sink_pos = sink_node.get("metadata", {}).get("position", {})
source_x = source_pos.get("x", 0)
sink_x = sink_pos.get("x", 0)
if abs(sink_x - source_x) < 800:
new_x = source_x + 800
if "metadata" not in sink_node:
sink_node["metadata"] = {}
if "position" not in sink_node["metadata"]:
sink_node["metadata"]["position"] = {}
sink_node["metadata"]["position"]["x"] = new_x
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
return agent
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
for node in agent.get("nodes", []):
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
input_default = node.get("input_default", {})
if "offset" in input_default:
offset = input_default["offset"]
if isinstance(offset, (int, float)) and offset < 0:
input_default["offset"] = abs(offset)
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
return agent
def fix_ai_model_parameter(
agent: dict[str, Any],
blocks_info: list[dict[str, Any]],
default_model: str = "gpt-4o",
) -> dict[str, Any]:
"""Add default model parameter to AI blocks if missing."""
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
# Check if block has AI category
categories = block.get("categories", [])
is_ai_block = any(
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
)
if is_ai_block:
input_default = node.get("input_default", {})
if "model" not in input_default:
input_default["model"] = default_model
node["input_default"] = input_default
logger.debug(
f"Added model '{default_model}' to AI block {node.get('id')}"
)
return agent
def fix_link_static_properties(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix is_static property based on source block's staticOutput."""
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
if not source_node:
continue
source_block = block_map.get(source_node.get("block_id"))
if not source_block:
continue
static_output = source_block.get("staticOutput", False)
if link.get("is_static") != static_output:
link["is_static"] = static_output
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
return agent
def fix_data_type_mismatch(
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> dict[str, Any]:
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
nodes = agent.get("nodes", [])
links = agent.get("links", [])
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in nodes}
def get_property_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
type_mapping = {
"string": "string",
"text": "string",
"integer": "number",
"number": "number",
"float": "number",
"boolean": "boolean",
"bool": "boolean",
"array": "list",
"list": "list",
"object": "dictionary",
"dict": "dictionary",
"dictionary": "dictionary",
}
new_links = []
nodes_to_add = []
for link in links:
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
new_links.append(link)
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
new_links.append(link)
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_property_type(source_outputs, link.get("source_name", ""))
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
# Insert type converter
converter_id = str(uuid.uuid4())
target_type = type_mapping.get(sink_type, sink_type)
converter_node = {
"id": converter_id,
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
"input_default": {"type": target_type},
"metadata": {"position": {"x": 0, "y": 100}},
}
nodes_to_add.append(converter_node)
# source -> converter
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": link["source_id"],
"source_name": link["source_name"],
"sink_id": converter_id,
"sink_name": "value",
"is_static": False,
}
)
# converter -> sink
new_links.append(
{
"id": str(uuid.uuid4()),
"source_id": converter_id,
"source_name": "value",
"sink_id": link["sink_id"],
"sink_name": link["sink_name"],
"is_static": False,
}
)
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
else:
new_links.append(link)
if nodes_to_add:
agent["nodes"] = nodes + nodes_to_add
agent["links"] = new_links
return agent
def apply_all_fixes(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""Apply all fixes to an agent JSON.
Args:
agent: Agent JSON dict
blocks_info: Optional list of block info dicts for advanced fixes
Returns:
Fixed agent JSON
"""
# Basic fixes (no block info needed)
agent = fix_agent_ids(agent)
agent = fix_double_curly_braces(agent)
agent = fix_storevalue_before_condition(agent)
agent = fix_addtolist_blocks(agent)
agent = fix_addtodictionary_blocks(agent)
agent = fix_code_execution_output(agent)
agent = fix_data_sampling_sample_size(agent)
agent = fix_node_x_coordinates(agent)
agent = fix_getcurrentdate_offset(agent)
# Advanced fixes (require block info)
if blocks_info is None:
blocks_info = get_blocks_info()
agent = fix_ai_model_parameter(agent, blocks_info)
agent = fix_link_static_properties(agent, blocks_info)
agent = fix_data_type_mismatch(agent, blocks_info)
return agent

View File

@@ -0,0 +1,225 @@
"""Prompt templates for agent generation."""
DECOMPOSITION_PROMPT = """
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
---
FIRST: Analyze the user's goal and determine:
1) Design-time configuration (fixed settings that won't change per run)
2) Runtime inputs (values the agent's end-user will provide each time it runs)
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
- DO NOT ask for the actual value
- Instead, define it as an Agent Input with a clear name, type, and description
Only ask clarifying questions about design-time config that affects how you build the workflow:
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
- Business rules that must be hard-coded
IMPORTANT CLARIFICATIONS POLICY:
- Ask no more than five essential questions
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
- Do not ask for API keys or credentials; the platform handles those directly
- If there is enough information to infer reasonable defaults, prefer to propose defaults
---
GUIDELINES:
1. List each step as a numbered item
2. Describe the action clearly and specify inputs/outputs
3. Ensure steps are in logical, sequential order
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
5. Help the user reach their goal efficiently
---
RULES:
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
2. USE ONLY THE BLOCKS PROVIDED
3. ALL required_input fields must be provided
4. Data types of linked properties must match
5. Write expert-level prompts for AI-related blocks
---
CRITICAL BLOCK RESTRICTIONS:
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
3. ConditionBlock: value2 is reference, value1 is contrast
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
---
OUTPUT FORMAT:
If more information is needed:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
"keyword": "email_provider",
"example": "Gmail"
}}
]
}}
```
If ready to proceed:
```json
{{
"type": "instructions",
"steps": [
{{
"step_number": 1,
"block_name": "AgentShortTextInputBlock",
"description": "Get the URL of the content to analyze.",
"inputs": [{{"name": "name", "value": "URL"}}],
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
}}
]
}}
```
---
AVAILABLE BLOCKS:
{block_summaries}
"""
GENERATION_PROMPT = """
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
---
NODES:
Each node must include:
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
- `block_id`: The block identifier (must match an Allowed Block)
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
- `metadata`: Must contain:
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
- `customized_name`: Clear name describing this block's purpose in the workflow
---
LINKS:
Each link connects a source node's output to a sink node's input:
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
- `source_id`: ID of the source node
- `source_name`: Output field name from the source block
- `sink_id`: ID of the sink node
- `sink_name`: Input field name on the sink block
- `is_static`: true only if source block has static_output: true
CRITICAL: All IDs must be valid UUID v4 format!
---
AGENT (GRAPH):
Wrap nodes and links in:
- `id`: UUID of the agent
- `name`: Short, generic name (avoid specific company names, URLs)
- `description`: Short, generic description
- `nodes`: List of all nodes
- `links`: List of all links
- `version`: 1
- `is_active`: true
---
TIPS:
- All required_input fields must be provided via input_default or a valid link
- Ensure consistent source_id and sink_id references
- Avoid dangling links
- Input/output pins must match block schemas
- Do not invent unknown block_ids
---
ALLOWED BLOCKS:
{block_summaries}
---
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
"""
PATCH_PROMPT = """
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
CURRENT AGENT:
{current_agent}
AVAILABLE BLOCKS:
{block_summaries}
---
PATCH FORMAT:
Return a JSON object with the following structure:
```json
{{
"type": "patch",
"intent": "Brief description of what the patch does",
"patches": [
{{
"type": "modify",
"node_id": "uuid-of-node-to-modify",
"changes": {{
"input_default": {{"field": "new_value"}},
"metadata": {{"customized_name": "New Name"}}
}}
}},
{{
"type": "add",
"new_nodes": [
{{
"id": "new-uuid",
"block_id": "block-uuid",
"input_default": {{}},
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
}}
],
"new_links": [
{{
"id": "link-uuid",
"source_id": "source-node-id",
"source_name": "output_field",
"sink_id": "sink-node-id",
"sink_name": "input_field"
}}
]
}},
{{
"type": "remove",
"node_ids": ["uuid-of-node-to-remove"],
"link_ids": ["uuid-of-link-to-remove"]
}}
]
}}
```
If you need more information, return:
```json
{{
"type": "clarifying_questions",
"questions": [
{{
"question": "What specific change do you want?",
"keyword": "change_type",
"example": "Add error handling"
}}
]
}}
```
Generate the minimal patch needed. Output ONLY valid JSON.
"""

View File

@@ -1,498 +0,0 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -0,0 +1,213 @@
"""Utilities for agent generation."""
import json
import re
from typing import Any
from backend.data.block import get_blocks
# UUID validation regex
UUID_REGEX = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
)
# Block IDs for various fixes
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
"3b191d9f-356f-482d-8238-ba04b6d18381",
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
"716a67b3-6760-42e7-86dc-18645c6e00fc",
"530cf046-2ce0-4854-ae2c-659db17c7a46",
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
]
def is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID v4."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def _compact_schema(schema: dict) -> dict[str, str]:
"""Extract compact type info from a JSON schema properties dict.
Returns a dict of {field_name: type_string} for essential info only.
"""
props = schema.get("properties", {})
result = {}
for name, prop in props.items():
# Skip internal/complex fields
if name.startswith("_"):
continue
# Get type string
type_str = prop.get("type", "any")
# Handle anyOf/oneOf (optional types)
if "anyOf" in prop:
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
type_str = "|".join(types) if types else "any"
elif "allOf" in prop:
type_str = "object"
# Add array item type if present
if type_str == "array" and "items" in prop:
items = prop["items"]
if isinstance(items, dict):
item_type = items.get("type", "any")
type_str = f"array[{item_type}]"
result[name] = type_str
return result
def get_block_summaries(include_schemas: bool = True) -> str:
"""Generate compact block summaries for prompts.
Args:
include_schemas: Whether to include input/output type info
Returns:
Formatted string of block summaries (compact format)
"""
blocks = get_blocks()
summaries = []
for block_id, block_cls in blocks.items():
block = block_cls()
name = block.name
desc = getattr(block, "description", "") or ""
# Truncate description
if len(desc) > 150:
desc = desc[:147] + "..."
if not include_schemas:
summaries.append(f"- {name} (id: {block_id}): {desc}")
else:
# Compact format with type info only
inputs = {}
outputs = {}
required = []
if hasattr(block, "input_schema"):
try:
schema = block.input_schema.jsonschema()
inputs = _compact_schema(schema)
required = schema.get("required", [])
except Exception:
pass
if hasattr(block, "output_schema"):
try:
schema = block.output_schema.jsonschema()
outputs = _compact_schema(schema)
except Exception:
pass
# Build compact line format
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
req_str = f" req=[{','.join(required)}]" if required else ""
static = " [static]" if getattr(block, "static_output", False) else ""
line = f"- {name} (id: {block_id}): {desc}"
if in_str:
line += f"\n in: {{{in_str}}}{req_str}"
if out_str:
line += f"\n out: {{{out_str}}}{static}"
summaries.append(line)
return "\n".join(summaries)
def get_blocks_info() -> list[dict[str, Any]]:
"""Get block information with schemas for validation and fixing."""
blocks = get_blocks()
blocks_info = []
for block_id, block_cls in blocks.items():
block = block_cls()
blocks_info.append(
{
"id": block_id,
"name": block.name,
"description": getattr(block, "description", ""),
"categories": getattr(block, "categories", []),
"staticOutput": getattr(block, "static_output", False),
"inputSchema": (
block.input_schema.jsonschema()
if hasattr(block, "input_schema")
else {}
),
"outputSchema": (
block.output_schema.jsonschema()
if hasattr(block, "output_schema")
else {}
),
}
)
return blocks_info
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
"""Extract JSON from LLM response (handles markdown code blocks)."""
if not text:
return None
# Try fenced code block
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if match:
try:
return json.loads(match.group(1).strip())
except json.JSONDecodeError:
pass
# Try raw text
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# Try finding {...} span
start = text.find("{")
end = text.rfind("}")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
# Try finding [...] span
start = text.find("[")
end = text.rfind("]")
if start != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
return None

View File

@@ -0,0 +1,279 @@
"""Agent validator - Validates agent structure and connections."""
import logging
import re
from typing import Any
from .utils import get_blocks_info
logger = logging.getLogger(__name__)
class AgentValidator:
"""Validator for AutoGPT agents with detailed error reporting."""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error: str) -> None:
"""Add an error message."""
self.errors.append(error)
def validate_block_existence(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate all block IDs exist in the blocks library."""
valid = True
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
)
valid = False
return valid
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
"""Validate all node IDs referenced in links exist."""
valid = True
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
if not source_id:
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent source_id '{source_id}'."
)
valid = False
if not sink_id:
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
)
valid = False
return valid
def validate_required_inputs(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate required inputs are provided."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_map.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
# Get linked inputs
linked_inputs = {
link["sink_name"]
for link in agent.get("links", [])
if link.get("sink_id") == node_id
}
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate linked data types are compatible."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
def get_type(schema: dict, name: str) -> str | None:
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema:
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_compatible(src: str, sink: str) -> bool:
if {src, sink} <= {"integer", "number"}:
return True
return src == sink
for link in agent.get("links", []):
source_node = node_lookup.get(link.get("source_id"))
sink_node = node_lookup.get(link.get("sink_id"))
if not source_node or not sink_node:
continue
source_block = block_map.get(source_node.get("block_id"))
sink_block = block_map.get(sink_node.get("block_id"))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_type(source_outputs, link.get("source_name", ""))
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
if source_type and sink_type and not are_compatible(source_type, sink_type):
self.add_error(
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
) -> bool:
"""Validate nested sink links (with _#_ notation)."""
valid = True
block_map = {b.get("id"): b for b in blocks_info}
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(link.get("sink_id"))
if not sink_node:
continue
block = block_map.get(sink_node.get("block_id"))
if not block:
continue
input_props = block.get("inputSchema", {}).get("properties", {})
parent_schema = input_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
)
valid = False
continue
if not parent_schema.get("additionalProperties"):
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and child in parent_schema.get("properties", {})
):
self.add_error(
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
)
valid = False
return valid
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
"""Validate prompts don't have spaces in template variables."""
valid = True
for node in agent.get("nodes", []):
input_default = node.get("input_default", {})
prompt = input_default.get("prompt", "")
if not isinstance(prompt, str):
continue
# Find {{...}} with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
for match in matches:
content = match.group(1)
if " " in content:
self.add_error(
f"Node '{node.get('id')}' has spaces in template variable: "
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
)
valid = False
return valid
def validate(
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Run all validations.
Returns:
Tuple of (is_valid, error_message)
"""
self.errors = []
if blocks_info is None:
blocks_info = get_blocks_info()
checks = [
self.validate_block_existence(agent, blocks_info),
self.validate_link_node_references(agent),
self.validate_required_inputs(agent, blocks_info),
self.validate_data_type_compatibility(agent, blocks_info),
self.validate_nested_sink_links(agent, blocks_info),
self.validate_prompt_spaces(agent),
]
all_passed = all(checks)
if all_passed:
logger.info("Agent validation successful")
return True, None
error_message = "Agent validation failed:\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
return False, error_message
def validate_agent(
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
"""Convenience function to validate an agent.
Returns:
Tuple of (is_valid, error_message)
"""
validator = AgentValidator()
return validator.validate(agent, blocks_info)

View File

@@ -5,6 +5,7 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from langfuse import observe
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
@@ -328,6 +329,7 @@ class AgentOutputTool(BaseTool):
total_executions=len(available_executions) if available_executions else 1,
)
@observe(as_type="tool", name="view_agent_output")
async def _execute(
self,
user_id: str | None,

View File

@@ -1,7 +1,6 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
import re
from typing import Literal
from backend.api.features.library import db as library_db
@@ -20,85 +19,6 @@ logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
_UUID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
re.IGNORECASE,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
async def search_agents(
query: str,
@@ -149,37 +69,29 @@ async def search_agents(
is_featured=False,
)
)
else:
if _is_uuid(query):
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
if agent:
agents.append(agent)
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass

View File

@@ -36,16 +36,6 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -3,22 +3,22 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
apply_all_fixes,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
get_blocks_info,
save_agent_to_library,
validate_agent,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -27,6 +27,9 @@ from .models import (
logger = logging.getLogger(__name__)
# Maximum retries for agent generation with validation feedback
MAX_GENERATION_RETRIES = 2
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@@ -46,10 +49,6 @@ class CreateAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -81,6 +80,7 @@ class CreateAgentTool(BaseTool):
"required": ["description"],
}
@observe(as_type="tool", name="create_agent")
async def _execute(
self,
user_id: str | None,
@@ -91,18 +91,15 @@ class CreateAgentTool(BaseTool):
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
2. Generate agent JSON from the steps
3. Apply fixes to correct common LLM errors
4. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -110,61 +107,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
library_agents = None
if user_id:
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
except AgentGeneratorNotConfiguredError:
decomposition_result = await decompose_goal(description, context)
except ValueError as e:
# Handle missing API key or configuration errors
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
message=f"Agent generation is not configured: {str(e)}",
error="configuration_error",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
session_id=session_id,
)
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
message="Failed to analyze the goal. Please try rephrasing.",
error="Decomposition failed",
session_id=session_id,
)
# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
@@ -183,6 +144,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
@@ -209,88 +171,72 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
# Step 2: Generate agent JSON with retry on validation failure
blocks_info = get_blocks_info()
agent_json = None
validation_errors = None
for attempt in range(MAX_GENERATION_RETRIES + 1):
# Generate agent (include validation errors from previous attempt)
if attempt == 0:
agent_json = await generate_agent(decomposition_result)
else:
# Retry with validation error feedback
logger.info(
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
retry_instructions = {
**decomposition_result,
"previous_errors": validation_errors,
"retry_instructions": (
"The previous generation had validation errors. "
"Please fix these issues in the new generation:\n"
f"{validation_errors}"
),
}
agent_json = await generate_agent(retry_instructions)
if agent_json is None:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message="Failed to generate the agent. Please try again.",
error="Generation failed",
session_id=session_id,
)
continue
# Step 3: Apply fixes to correct common errors
agent_json = apply_all_fixes(agent_json, blocks_info)
# Step 4: Validate the agent
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
if is_valid:
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
break
logger.warning(
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
)
if attempt == MAX_GENERATION_RETRIES:
# Return error with validation details
return ErrorResponse(
message=(
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
f"Please try rephrasing your request or simplify the workflow."
),
error="validation_failed",
details={"validation_errors": validation_errors},
session_id=session_id,
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 4: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -305,6 +251,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -322,7 +269,7 @@ class CreateAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -1,337 +0,0 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@property
def name(self) -> str:
return "customize_agent"
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
),
},
"modifications": {
"type": "string",
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["agent_id", "modifications"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
error="invalid_agent_id_format",
session_id=session_id,
)
creator_username, agent_slug = parts
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -3,21 +3,23 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
apply_agent_patch,
apply_all_fixes,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
get_blocks_info,
save_agent_to_library,
validate_agent,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -26,6 +28,9 @@ from .models import (
logger = logging.getLogger(__name__)
# Maximum retries for patch generation with validation feedback
MAX_GENERATION_RETRIES = 2
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@@ -38,17 +43,13 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts."
"Generates a patch to update the agent while preserving unchanged parts."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -86,6 +87,7 @@ class EditAgentTool(BaseTool):
"required": ["agent_id", "changes"],
}
@observe(as_type="tool", name="edit_agent")
async def _execute(
self,
user_id: str | None,
@@ -96,8 +98,9 @@ class EditAgentTool(BaseTool):
Flow:
1. Fetch the current agent
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
2. Generate a patch based on the requested changes
3. Apply the patch to create an updated agent
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
@@ -105,10 +108,6 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -123,6 +122,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
@@ -132,117 +132,126 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
library_agents = None
if user_id:
try:
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
# Build the update request with context
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent editing is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
# Step 2: Generate patch with retry on validation failure
blocks_info = get_blocks_info()
updated_agent = None
validation_errors = None
intent = "Applied requested changes"
if result is None:
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
for attempt in range(MAX_GENERATION_RETRIES + 1):
# Generate patch (include validation errors from previous attempt)
try:
if attempt == 0:
patch_result = await generate_agent_patch(
update_request, current_agent
)
for q in questions
],
session_id=session_id,
else:
# Retry with validation error feedback
logger.info(
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
)
retry_request = (
f"{update_request}\n\n"
f"IMPORTANT: The previous edit had validation errors. "
f"Please fix these issues:\n{validation_errors}"
)
patch_result = await generate_agent_patch(
retry_request, current_agent
)
except ValueError as e:
# Handle missing API key or configuration errors
return ErrorResponse(
message=f"Agent generation is not configured: {str(e)}",
error="configuration_error",
session_id=session_id,
)
if patch_result is None:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message="Failed to generate changes. Please try rephrasing.",
error="Patch generation failed",
session_id=session_id,
)
continue
# Check if LLM returned clarifying questions
if patch_result.get("type") == "clarifying_questions":
questions = patch_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Step 3: Apply patch and fixes
try:
updated_agent = apply_agent_patch(current_agent, patch_result)
updated_agent = apply_all_fixes(updated_agent, blocks_info)
except Exception as e:
if attempt == MAX_GENERATION_RETRIES:
return ErrorResponse(
message=f"Failed to apply changes: {str(e)}",
error="patch_apply_failed",
details={"exception": str(e)},
session_id=session_id,
)
validation_errors = str(e)
continue
# Step 4: Validate the updated agent
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
if is_valid:
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
intent = patch_result.get("intent", "Applied requested changes")
break
logger.warning(
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
)
updated_agent = result
if attempt == MAX_GENERATION_RETRIES:
# Return error with validation details
return ErrorResponse(
message=(
f"Updated agent has validation errors after "
f"{MAX_GENERATION_RETRIES + 1} attempts. "
f"Please try rephrasing your request or simplify the changes."
),
error="validation_failed",
details={"validation_errors": validation_errors},
session_id=session_id,
)
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
assert updated_agent is not None
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 5: Preview or save
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "
f"I've updated the agent. Changes: {intent}. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
@@ -254,6 +263,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -267,11 +277,14 @@ class EditAgentTool(BaseTool):
)
return AgentSavedResponse(
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
message=(
f"Updated agent '{created_graph.name}' has been saved to your library! "
f"Changes: {intent}"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -2,6 +2,8 @@
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -35,6 +37,7 @@ class FindAgentTool(BaseTool):
"required": ["query"],
}
@observe(as_type="tool", name="find_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -1,6 +1,7 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -55,6 +56,7 @@ class FindBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
@observe(as_type="tool", name="find_block")
async def _execute(
self,
user_id: str | None,
@@ -107,8 +109,7 @@ class FindBlockTool(BaseTool):
block_id = result["content_id"]
block = get_block(block_id)
# Skip disabled blocks
if block and not block.disabled:
if block:
# Get input/output schemas
input_schema = {}
output_schema = {}

View File

@@ -2,6 +2,8 @@
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -41,6 +43,7 @@ class FindLibraryAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@observe(as_type="tool", name="find_library_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -4,6 +4,8 @@ import logging
from pathlib import Path
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -71,6 +73,7 @@ class GetDocPageTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="get_doc_page")
async def _execute(
self,
user_id: str | None,

View File

@@ -28,18 +28,6 @@ class ResponseType(str, Enum):
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Base response model
@@ -70,10 +58,6 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)
class AgentsFoundResponse(ToolResponseBase):
@@ -200,20 +184,6 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
unrecognized_fields: list[str] = Field(
description="List of input field names that were not recognized"
)
inputs: dict[str, Any] = Field(
description="The agent's valid input schema for reference"
)
graph_id: str | None = None
graph_version: int | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
@@ -364,60 +334,3 @@ class BlockOutputResponse(ToolResponseBase):
block_name: str
outputs: dict[str, list[Any]]
success: bool = True
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None

View File

@@ -3,14 +3,11 @@
import logging
from typing import Any
from langfuse import observe
from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
track_agent_run_success,
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
@@ -30,7 +27,6 @@ from .models import (
ErrorResponse,
ExecutionOptions,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
@@ -159,6 +155,7 @@ class RunAgentTool(BaseTool):
"""All operations require authentication."""
return True
@observe(as_type="tool", name="run_agent")
async def _execute(
self,
user_id: str | None,
@@ -274,22 +271,6 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide
@@ -472,16 +453,6 @@ class RunAgentTool(BaseTool):
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
)
# Track in PostHog
track_agent_run_success(
user_id=user_id,
session_id=session_id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
execution_id=execution.id,
library_agent_id=library_agent.id,
)
library_agent_link = f"/library/agents/{library_agent.id}"
return ExecutionStartedResponse(
message=(
@@ -563,18 +534,6 @@ class RunAgentTool(BaseTool):
session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
)
# Track in PostHog
track_agent_scheduled(
user_id=user_id,
session_id=session_id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
schedule_id=result.id,
schedule_name=schedule_name,
cron=cron,
library_agent_id=library_agent.id,
)
library_agent_link = f"/library/agents/{library_agent.id}"
return ExecutionStartedResponse(
message=(

View File

@@ -29,7 +29,7 @@ def mock_embedding_functions():
yield
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent(setup_test_data):
"""Test that the run_agent tool successfully executes an approved agent"""
# Use test data from fixture
@@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data):
assert result_data["graph_name"] == "Test Agent"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_missing_inputs(setup_test_data):
"""Test that the run_agent tool returns error when inputs are missing"""
# Use test data from fixture
@@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
assert "message" in result_data
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_invalid_agent_id(setup_test_data):
"""Test that the run_agent tool returns error for invalid agent ID"""
# Use test data from fixture
@@ -141,7 +141,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
"""Test that run_agent works with an agent requiring LLM credentials"""
# Use test data from fixture
@@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
assert result_data["graph_name"] == "LLM Test Agent"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
"""Test that run_agent returns available inputs when called without inputs or use_defaults."""
user = setup_test_data["user"]
@@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
assert "inputs" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_with_use_defaults(setup_test_data):
"""Test that run_agent executes successfully with use_defaults=True."""
user = setup_test_data["user"]
@@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data):
assert result_data["graph_id"] == graph.id
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
"""Test that run_agent returns setup_requirements when credentials are missing."""
user = setup_firecrawl_test_data["user"]
@@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_invalid_slug_format(setup_test_data):
"""Test that run_agent returns error for invalid slug format (no slash)."""
user = setup_test_data["user"]
@@ -313,7 +313,7 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
assert "username/agent-name" in result_data["message"]
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_unauthenticated():
"""Test that run_agent returns need_login for unauthenticated users."""
tool = RunAgentTool()
@@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated():
assert "sign in" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_schedule_without_cron(setup_test_data):
"""Test that run_agent returns error when scheduling without cron expression."""
user = setup_test_data["user"]
@@ -372,7 +372,7 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
assert "cron" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(scope="session")
async def test_run_agent_schedule_without_name(setup_test_data):
"""Test that run_agent returns error when scheduling without schedule_name."""
user = setup_test_data["user"]
@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

View File

@@ -1,17 +1,15 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
@@ -158,6 +130,7 @@ class RunBlockTool(BaseTool):
return matched_credentials, missing_credentials
@observe(as_type="tool", name="run_block")
async def _execute(
self,
user_id: str | None,
@@ -206,17 +179,13 @@ class RunBlockTool(BaseTool):
message=f"Block '{block_id}' not found",
session_id=session_id,
)
if block.disabled:
return ErrorResponse(
message=f"Block '{block_id}' is disabled",
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
user_id, block
)
if missing_credentials:
@@ -252,48 +221,11 @@ class RunBlockTool(BaseTool):
)
try:
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
"execution_context": ExecutionContext(),
}
for field_name, cred_meta in matched_credentials.items():

View File

@@ -3,6 +3,7 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -87,6 +88,7 @@ class SearchDocsTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="search_docs")
async def _execute(
self,
user_id: str | None,

View File

@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError
@@ -266,14 +266,13 @@ async def match_user_credentials_to_graph(
credential_requirements,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes
# Find first matching credential by provider and type
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
),
None,
)
@@ -297,17 +296,10 @@ async def match_user_credentials_to_graph(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} (requires {', '.join(error_parts)})"
f"{credential_field_name} "
f"(requires provider in {list(credential_requirements.provider)}, "
f"type in {list(credential_requirements.supported_types)})"
)
logger.info(
@@ -317,28 +309,6 @@ async def match_user_credentials_to_graph(
return graph_credentials_inputs, missing_creds
def _credential_has_required_scopes(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -1,620 +0,0 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -1,250 +0,0 @@
"""PostHog analytics tracking for the chat system."""
import atexit
import logging
from typing import Any
from posthog import Posthog
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
# PostHog client instance (lazily initialized)
_posthog_client: Posthog | None = None
def _shutdown_posthog() -> None:
"""Flush and shutdown PostHog client on process exit."""
if _posthog_client is not None:
_posthog_client.flush()
_posthog_client.shutdown()
atexit.register(_shutdown_posthog)
def _get_posthog_client() -> Posthog | None:
"""Get or create the PostHog client instance."""
global _posthog_client
if _posthog_client is not None:
return _posthog_client
if not settings.secrets.posthog_api_key:
logger.debug("PostHog API key not configured, analytics disabled")
return None
_posthog_client = Posthog(
settings.secrets.posthog_api_key,
host=settings.secrets.posthog_host,
)
logger.info(
f"PostHog client initialized with host: {settings.secrets.posthog_host}"
)
return _posthog_client
def _get_base_properties() -> dict[str, Any]:
"""Get base properties included in all events."""
return {
"environment": settings.config.app_env.value,
"source": "chat_copilot",
}
def track_user_message(
user_id: str | None,
session_id: str,
message_length: int,
) -> None:
"""Track when a user sends a message in chat.
Args:
user_id: The user's ID (or None for anonymous)
session_id: The chat session ID
message_length: Length of the user's message
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"message_length": message_length,
}
client.capture(
distinct_id=user_id or f"anonymous_{session_id}",
event="copilot_message_sent",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track user message: {e}")
def track_tool_called(
user_id: str | None,
session_id: str,
tool_name: str,
tool_call_id: str,
) -> None:
"""Track when a tool is called in chat.
Args:
user_id: The user's ID (or None for anonymous)
session_id: The chat session ID
tool_name: Name of the tool being called
tool_call_id: Unique ID of the tool call
"""
client = _get_posthog_client()
if not client:
logger.info("PostHog client not available for tool tracking")
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"tool_name": tool_name,
"tool_call_id": tool_call_id,
}
distinct_id = user_id or f"anonymous_{session_id}"
logger.info(
f"Sending copilot_tool_called event to PostHog: distinct_id={distinct_id}, "
f"tool_name={tool_name}"
)
client.capture(
distinct_id=distinct_id,
event="copilot_tool_called",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track tool call: {e}")
def track_agent_run_success(
user_id: str,
session_id: str,
graph_id: str,
graph_name: str,
execution_id: str,
library_agent_id: str,
) -> None:
"""Track when an agent is successfully run.
Args:
user_id: The user's ID
session_id: The chat session ID
graph_id: ID of the agent graph
graph_name: Name of the agent
execution_id: ID of the execution
library_agent_id: ID of the library agent
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"graph_id": graph_id,
"graph_name": graph_name,
"execution_id": execution_id,
"library_agent_id": library_agent_id,
}
client.capture(
distinct_id=user_id,
event="copilot_agent_run_success",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track agent run: {e}")
def track_agent_scheduled(
user_id: str,
session_id: str,
graph_id: str,
graph_name: str,
schedule_id: str,
schedule_name: str,
cron: str,
library_agent_id: str,
) -> None:
"""Track when an agent is successfully scheduled.
Args:
user_id: The user's ID
session_id: The chat session ID
graph_id: ID of the agent graph
graph_name: Name of the agent
schedule_id: ID of the schedule
schedule_name: Name of the schedule
cron: Cron expression for the schedule
library_agent_id: ID of the library agent
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"graph_id": graph_id,
"graph_name": graph_name,
"schedule_id": schedule_id,
"schedule_name": schedule_name,
"cron": cron,
"library_agent_id": library_agent_id,
}
client.capture(
distinct_id=user_id,
event="copilot_agent_scheduled",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track agent schedule: {e}")
def track_trigger_setup(
user_id: str,
session_id: str,
graph_id: str,
graph_name: str,
trigger_type: str,
library_agent_id: str,
) -> None:
"""Track when a trigger is set up for an agent.
Args:
user_id: The user's ID
session_id: The chat session ID
graph_id: ID of the agent graph
graph_name: Name of the agent
trigger_type: Type of trigger (e.g., 'webhook')
library_agent_id: ID of the library agent
"""
client = _get_posthog_client()
if not client:
return
try:
properties = {
**_get_base_properties(),
"session_id": session_id,
"graph_id": graph_id,
"graph_name": graph_name,
"trigger_type": trigger_type,
"library_agent_id": library_agent_id,
}
client.capture(
distinct_id=user_id,
event="copilot_trigger_setup",
properties=properties,
)
except Exception as e:
logger.warning(f"Failed to track trigger setup: {e}")

View File

@@ -23,7 +23,6 @@ class PendingHumanReviewModel(BaseModel):
id: Unique identifier for the review record
user_id: ID of the user who must perform the review
node_exec_id: ID of the node execution that created this review
node_id: ID of the node definition (for grouping reviews from same node)
graph_exec_id: ID of the graph execution containing the node
graph_id: ID of the graph template being executed
graph_version: Version number of the graph template
@@ -38,10 +37,6 @@ class PendingHumanReviewModel(BaseModel):
"""
node_exec_id: str = Field(description="Node execution ID (primary key)")
node_id: str = Field(
description="Node definition ID (for grouping)",
default="", # Temporary default for test compatibility
)
user_id: str = Field(description="User ID associated with the review")
graph_exec_id: str = Field(description="Graph execution ID")
graph_id: str = Field(description="Graph ID")
@@ -71,9 +66,7 @@ class PendingHumanReviewModel(BaseModel):
)
@classmethod
def from_db(
cls, review: "PendingHumanReview", node_id: str
) -> "PendingHumanReviewModel":
def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel":
"""
Convert a database model to a response model.
@@ -81,14 +74,9 @@ class PendingHumanReviewModel(BaseModel):
payload, instructions, and editable flag.
Handles invalid data gracefully by using safe defaults.
Args:
review: Database review object
node_id: Node definition ID (fetched from NodeExecution)
"""
return cls(
node_exec_id=review.nodeExecId,
node_id=node_id,
user_id=review.userId,
graph_exec_id=review.graphExecId,
graph_id=review.graphId,
@@ -119,13 +107,6 @@ class ReviewItem(BaseModel):
reviewed_data: SafeJsonData | None = Field(
None, description="Optional edited data (ignored if approved=False)"
)
auto_approve_future: bool = Field(
default=False,
description=(
"If true and this review is approved, future executions of this same "
"block (node) will be automatically approved. This only affects approved reviews."
),
)
@field_validator("reviewed_data")
@classmethod
@@ -193,9 +174,6 @@ class ReviewRequest(BaseModel):
This request must include ALL pending reviews for a graph execution.
Each review will be either approved (with optional data modifications)
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
Each review item can individually specify whether to auto-approve future executions
of the same block via the `auto_approve_future` field on ReviewItem.
"""
reviews: List[ReviewItem] = Field(

View File

@@ -1,27 +1,17 @@
import asyncio
import logging
from typing import Any, List
from typing import List
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
)
from backend.data.graph import get_graph_settings
from backend.data.execution import get_graph_execution_meta
from backend.data.human_review import (
create_auto_approval_record,
get_pending_reviews_for_execution,
get_pending_reviews_for_user,
get_reviews_by_node_exec_ids,
has_pending_reviews_for_graph_exec,
process_all_reviews_for_execution,
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -137,70 +127,17 @@ async def process_review_action(
detail="At least one review must be provided",
)
# Batch fetch all requested reviews (regardless of status for idempotent handling)
reviews_map = await get_reviews_by_node_exec_ids(
list(all_request_node_ids), user_id
)
# Validate all reviews were found (must exist, any status is OK for now)
missing_ids = all_request_node_ids - set(reviews_map.keys())
if missing_ids:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Review(s) not found: {', '.join(missing_ids)}",
)
# Validate all reviews belong to the same execution
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
if len(graph_exec_ids) > 1:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="All reviews in a single request must belong to the same execution.",
)
graph_exec_id = next(iter(graph_exec_ids))
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
# Build review decisions map
review_decisions = {}
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
for review in request.reviews:
review_status = (
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
)
# If this review requested auto-approval, don't allow data modifications
reviewed_data = None if review.auto_approve_future else review.reviewed_data
review_decisions[review.node_exec_id] = (
review_status,
reviewed_data,
review.reviewed_data,
review.message,
)
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
# Process all reviews
updated_reviews = await process_all_reviews_for_execution(
@@ -208,87 +145,6 @@ async def process_review_action(
review_decisions=review_decisions,
)
# Create auto-approval records for approved reviews that requested it
# Deduplicate by node_id to avoid race conditions when multiple reviews
# for the same node are processed in parallel
async def create_auto_approval_for_node(
node_id: str, review_result
) -> tuple[str, bool]:
"""
Create auto-approval record for a node.
Returns (node_id, success) tuple for tracking failures.
"""
try:
await create_auto_approval_record(
user_id=user_id,
graph_exec_id=review_result.graph_exec_id,
graph_id=review_result.graph_id,
graph_version=review_result.graph_version,
node_id=node_id,
payload=review_result.payload,
)
return (node_id, True)
except Exception as e:
logger.error(
f"Failed to create auto-approval record for node {node_id}",
exc_info=e,
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
if review_result.status == ReviewStatus.APPROVED
and auto_approve_requests.get(node_exec_id, False)
]
# Batch-fetch node executions to get node_ids
nodes_needing_auto_approval: dict[str, Any] = {}
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
*[
create_auto_approval_for_node(node_id, review_result)
for node_id, review_result in nodes_needing_auto_approval.items()
],
return_exceptions=True,
)
# Count auto-approval failures
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
approved_count = sum(
1
@@ -301,53 +157,30 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume execution only if ALL pending reviews for this execution have been processed
# Resume execution if we processed some reviews
if updated_reviews:
# Get graph execution ID from any processed review
first_review = next(iter(updated_reviews.values()))
graph_exec_id = first_review.graph_exec_id
# Check if any pending reviews remain for this execution
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
# Resume execution
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
)
await add_graph_execution(
graph_id=first_review.graph_id,
user_id=user_id,
graph_exec_id=graph_exec_id,
execution_context=execution_context,
)
logger.info(f"Resumed execution {graph_exec_id}")
except Exception as e:
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
# Build error message if auto-approvals failed
error_message = None
if auto_approval_failed_count > 0:
error_message = (
f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
f"You may need to manually approve these reviews in future executions."
)
return ReviewResponse(
approved_count=approved_count,
rejected_count=rejected_count,
failed_count=auto_approval_failed_count,
error=error_message,
failed_count=0,
error=None,
)

View File

@@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.exceptions import DatabaseError, NotFoundError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.settings import Config
@@ -39,7 +39,6 @@ async def list_library_agents(
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
page: int = 1,
page_size: int = 50,
include_executions: bool = False,
) -> library_model.LibraryAgentResponse:
"""
Retrieves a paginated list of LibraryAgent records for a given user.
@@ -50,9 +49,6 @@ async def list_library_agents(
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
page: Current page (1-indexed).
page_size: Number of items per page.
include_executions: Whether to include execution data for status calculation.
Defaults to False for performance (UI fetches status separately).
Set to True when accurate status/metrics are needed (e.g., agent generator).
Returns:
A LibraryAgentResponse containing the list of agents and pagination details.
@@ -68,11 +64,11 @@ async def list_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise InvalidInputError("Invalid pagination input")
raise DatabaseError("Invalid pagination input")
if search_term and len(search_term.strip()) > 100:
logger.warning(f"Search term too long: {repr(search_term)}")
raise InvalidInputError("Search term is too long")
raise DatabaseError("Search term is too long")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -80,6 +76,7 @@ async def list_library_agents(
"isArchived": False,
}
# Build search filter if applicable
if search_term:
where_clause["OR"] = [
{
@@ -96,6 +93,7 @@ async def list_library_agents(
},
]
# Determine sorting
order_by: prisma.types.LibraryAgentOrderByInput | None = None
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
@@ -107,7 +105,7 @@ async def list_library_agents(
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include=library_agent_include(
user_id, include_nodes=False, include_executions=include_executions
user_id, include_nodes=False, include_executions=False
),
order=order_by,
skip=(page - 1) * page_size,
@@ -177,7 +175,7 @@ async def list_favorite_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise InvalidInputError("Invalid pagination input")
raise DatabaseError("Invalid pagination input")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -585,13 +583,7 @@ async def update_library_agent(
)
update_fields["isDeleted"] = is_deleted
if settings is not None:
existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id)
current_settings_dict = (
existing_agent.settings.model_dump() if existing_agent.settings else {}
)
new_settings = settings.model_dump(exclude_unset=True)
merged_settings = {**current_settings_dict, **new_settings}
update_fields["settings"] = SafeJson(merged_settings)
update_fields["settings"] = SafeJson(settings.model_dump())
try:
# If graph_version is provided, update to that specific version

View File

@@ -9,7 +9,6 @@ import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.json import loads as json_loads
from backend.util.models import Pagination
if TYPE_CHECKING:
@@ -17,10 +16,10 @@ if TYPE_CHECKING:
class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED"
HEALTHY = "HEALTHY"
WAITING = "WAITING"
ERROR = "ERROR"
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state
class MarketplaceListingCreator(pydantic.BaseModel):
@@ -40,30 +39,6 @@ class MarketplaceListing(pydantic.BaseModel):
creator: MarketplaceListingCreator
class RecentExecution(pydantic.BaseModel):
"""Summary of a recent execution for quality assessment.
Used by the LLM to understand the agent's recent performance with specific examples
rather than just aggregate statistics.
"""
status: str
correctness_score: float | None = None
activity_summary: str | None = None
def _parse_settings(settings: dict | str | None) -> GraphSettings:
"""Parse settings from database, handling both dict and string formats."""
if settings is None:
return GraphSettings()
try:
if isinstance(settings, str):
settings = json_loads(settings)
return GraphSettings.model_validate(settings)
except Exception:
return GraphSettings()
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -73,7 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -89,7 +64,7 @@ class LibraryAgent(pydantic.BaseModel):
description: str
instructions: str | None = None
input_schema: dict[str, Any]
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any]
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
@@ -106,19 +81,25 @@ class LibraryAgent(pydantic.BaseModel):
)
trigger_setup_info: Optional[GraphTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
execution_count: int = 0
success_rate: float | None = None
avg_correctness_score: float | None = None
recent_executions: list[RecentExecution] = pydantic.Field(
default_factory=list,
description="List of recent executions with status, score, and summary",
)
# Whether the user can access the underlying graph
can_access_graph: bool
# Indicates if this agent is the latest version
is_latest_version: bool
# Whether the agent is marked as favorite by the user
is_favorite: bool
# Recommended schedule cron (from marketplace agents)
recommended_schedule_cron: str | None = None
# User-specific settings for this library agent
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
# Marketplace listing information if the agent has been published
marketplace_listing: Optional["MarketplaceListing"] = None
@staticmethod
@@ -142,6 +123,7 @@ class LibraryAgent(pydantic.BaseModel):
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
@@ -154,6 +136,7 @@ class LibraryAgent(pydantic.BaseModel):
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""
# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
@@ -162,55 +145,13 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
success_count = sum(
1
for e in executions
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
)
success_rate = (success_count / execution_count) * 100
correctness_scores = []
for e in executions:
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
correctness_scores.append(float(score))
if correctness_scores:
avg_correctness_score = sum(correctness_scores) / len(
correctness_scores
)
recent_executions: list[RecentExecution] = []
for e in executions:
exec_score: float | None = None
exec_summary: str | None = None
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
exec_score = float(score)
summary = e.stats.get("activity_status")
if summary is not None and isinstance(summary, str):
exec_summary = summary
exec_status = (
e.executionStatus.value
if hasattr(e.executionStatus, "value")
else str(e.executionStatus)
)
recent_executions.append(
RecentExecution(
status=exec_status,
correctness_score=exec_score,
activity_summary=exec_summary,
)
)
# Check if user can access the graph
can_access_graph = agent.AgentGraph.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
# Build marketplace_listing if available
marketplace_listing_data = None
if store_listing and store_listing.ActiveVersion and profile:
creator_data = MarketplaceListingCreator(
@@ -249,15 +190,11 @@ class LibraryAgent(pydantic.BaseModel):
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
execution_count=execution_count,
success_rate=success_rate,
avg_correctness_score=avg_correctness_score,
recent_executions=recent_executions,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
settings=GraphSettings.model_validate(agent.settings),
marketplace_listing=marketplace_listing_data,
)
@@ -283,15 +220,18 @@ def _calculate_agent_status(
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False
for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1
# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:

View File

@@ -1,3 +1,4 @@
import logging
from typing import Literal, Optional
import autogpt_libs.auth as autogpt_auth_lib
@@ -5,11 +6,15 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
from prisma.enums import OnboardingStep
import backend.api.features.store.exceptions as store_exceptions
from backend.data.onboarding import complete_onboarding_step
from backend.util.exceptions import DatabaseError, NotFoundError
from .. import db as library_db
from .. import model as library_model
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/agents",
tags=["library", "private"],
@@ -21,6 +26,10 @@ router = APIRouter(
"",
summary="List Library Agents",
response_model=library_model.LibraryAgentResponse,
responses={
200: {"description": "List of library agents"},
500: {"description": "Server error", "content": {"application/json": {}}},
},
)
async def list_library_agents(
user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -44,19 +53,43 @@ async def list_library_agents(
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
Args:
user_id: ID of the authenticated user.
search_term: Optional search term to filter agents by name/description.
filter_by: List of filters to apply (favorites, created by user).
sort_by: List of sorting criteria (created date, updated date).
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
"""
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
try:
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
@router.get(
"/favorites",
summary="List Favorite Library Agents",
responses={
500: {"description": "Server error", "content": {"application/json": {}}},
},
)
async def list_favorite_library_agents(
user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -73,12 +106,30 @@ async def list_favorite_library_agents(
) -> library_model.LibraryAgentResponse:
"""
Get all favorite agents in the user's library.
Args:
user_id: ID of the authenticated user.
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing favorite agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
"""
return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
try:
return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
@router.get("/{library_agent_id}", summary="Get Library Agent")
@@ -111,6 +162,10 @@ async def get_library_agent_by_graph_id(
summary="Get Agent By Store ID",
tags=["store", "library"],
response_model=library_model.LibraryAgent | None,
responses={
200: {"description": "Library agent found"},
404: {"description": "Agent not found"},
},
)
async def get_library_agent_by_store_listing_version_id(
store_listing_version_id: str,
@@ -119,15 +174,32 @@ async def get_library_agent_by_store_listing_version_id(
"""
Get Library Agent from Store Listing Version ID.
"""
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
try:
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.error(f"Could not fetch library agent from store version ID: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
@router.post(
"",
summary="Add Marketplace Agent",
status_code=status.HTTP_201_CREATED,
responses={
201: {"description": "Agent added successfully"},
404: {"description": "Store listing version not found"},
500: {"description": "Server error"},
},
)
async def add_marketplace_agent_to_library(
store_listing_version_id: str = Body(embed=True),
@@ -138,19 +210,59 @@ async def add_marketplace_agent_to_library(
) -> library_model.LibraryAgent:
"""
Add an agent from the marketplace to the user's library.
Args:
store_listing_version_id: ID of the store listing version to add.
user_id: ID of the authenticated user.
Returns:
library_model.LibraryAgent: Agent added to the library
Raises:
HTTPException(404): If the listing version is not found.
HTTPException(500): If a server/database error occurs.
"""
agent = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
return agent
try:
agent = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
)
return agent
except store_exceptions.AgentNotFoundError as e:
logger.warning(
f"Could not find store listing version {store_listing_version_id} "
"to add to library"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except DatabaseError as e:
logger.error(f"Database error while adding agent to library: {e}", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Inspect DB logs for details."},
) from e
except Exception as e:
logger.error(f"Unexpected error while adding agent to library: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": str(e),
"hint": "Check server logs for more information.",
},
) from e
@router.patch(
"/{library_agent_id}",
summary="Update Library Agent",
responses={
200: {"description": "Agent updated successfully"},
500: {"description": "Server error"},
},
)
async def update_library_agent(
library_agent_id: str,
@@ -159,21 +271,52 @@ async def update_library_agent(
) -> library_model.LibraryAgent:
"""
Update the library agent with the given fields.
Args:
library_agent_id: ID of the library agent to update.
payload: Fields to update (auto_update_version, is_favorite, etc.).
user_id: ID of the authenticated user.
Raises:
HTTPException(500): If a server/database error occurs.
"""
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
graph_version=payload.graph_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
)
try:
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
graph_version=payload.graph_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except DatabaseError as e:
logger.error(f"Database error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Verify DB connection."},
) from e
except Exception as e:
logger.error(f"Unexpected error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check server logs."},
) from e
@router.delete(
"/{library_agent_id}",
summary="Delete Library Agent",
responses={
204: {"description": "Agent deleted successfully"},
404: {"description": "Agent not found"},
500: {"description": "Server error"},
},
)
async def delete_library_agent(
library_agent_id: str,
@@ -181,11 +324,28 @@ async def delete_library_agent(
) -> Response:
"""
Soft-delete the specified library agent.
Args:
library_agent_id: ID of the library agent to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
Raises:
HTTPException(404): If the agent does not exist.
HTTPException(500): If a server/database error occurs.
"""
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
try:
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")

View File

@@ -118,6 +118,21 @@ async def test_get_library_agents_success(
)
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents?search_term=test")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id=test_user_id,
search_term="test",
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
)
@pytest.mark.asyncio
async def test_get_favorite_library_agents_success(
mocker: pytest_mock.MockFixture,
@@ -175,6 +190,23 @@ async def test_get_favorite_library_agents_success(
)
def test_get_favorite_library_agents_error(
mocker: pytest_mock.MockFixture, test_user_id: str
):
mock_db_call = mocker.patch(
"backend.api.features.library.db.list_favorite_library_agents"
)
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents/favorites")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id=test_user_id,
page=1,
page_size=15,
)
def test_add_agent_to_library_success(
mocker: pytest_mock.MockFixture, test_user_id: str
):
@@ -226,3 +258,19 @@ def test_add_agent_to_library_success(
store_listing_version_id="test-version-id", user_id=test_user_id
)
mock_complete_onboarding.assert_awaited_once()
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
mock_db_call = mocker.patch(
"backend.api.features.library.db.add_store_agent_to_library"
)
mock_db_call.side_effect = Exception("Test error")
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 500
assert "detail" in response.json() # Verify error response structure
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id=test_user_id
)

View File

@@ -20,7 +20,6 @@ from typing import AsyncGenerator
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
@@ -39,13 +38,13 @@ keysmith = APIKeySmith()
# ============================================================================
@pytest.fixture(scope="session")
@pytest.fixture
def test_user_id() -> str:
"""Test user ID for OAuth tests."""
return str(uuid.uuid4())
@pytest_asyncio.fixture(scope="session", loop_scope="session")
@pytest.fixture
async def test_user(server, test_user_id: str):
"""Create a test user in the database."""
await PrismaUser.prisma().create(
@@ -68,7 +67,7 @@ async def test_user(server, test_user_id: str):
await PrismaUser.prisma().delete(where={"id": test_user_id})
@pytest_asyncio.fixture
@pytest.fixture
async def test_oauth_app(test_user: str):
"""Create a test OAuth application in the database."""
app_id = str(uuid.uuid4())
@@ -123,7 +122,7 @@ def pkce_credentials() -> tuple[str, str]:
return generate_pkce()
@pytest_asyncio.fixture
@pytest.fixture
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
"""
Create an async HTTP client that talks directly to the FastAPI app.
@@ -166,7 +165,7 @@ async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, No
@pytest.mark.asyncio(loop_scope="session")
async def test_authorize_creates_code_in_database(
async def test_authorize_creates_code_in_database_test(
client: httpx.AsyncClient,
test_user: str,
test_oauth_app: dict,
@@ -288,7 +287,7 @@ async def test_authorize_invalid_client_returns_error(
assert query_params["error"][0] == "invalid_client"
@pytest_asyncio.fixture
@pytest.fixture
async def inactive_oauth_app(test_user: str):
"""Create an inactive test OAuth application in the database."""
app_id = str(uuid.uuid4())
@@ -1005,7 +1004,7 @@ async def test_token_refresh_revoked(
assert "revoked" in response.json()["detail"].lower()
@pytest_asyncio.fixture
@pytest.fixture
async def other_oauth_app(test_user: str):
"""Create a second OAuth application for cross-app tests."""
app_id = str(uuid.uuid4())

View File

@@ -188,10 +188,6 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if hasattr(block_instance, "name") and block_instance.name:
@@ -252,19 +248,12 @@ class BlockHandler(ContentHandler):
from backend.data.block import get_blocks
all_blocks = get_blocks()
# Filter out disabled blocks - they're not indexed
enabled_block_ids = [
block_id
for block_id, block_cls in all_blocks.items()
if not block_cls().disabled
]
total_blocks = len(enabled_block_ids)
total_blocks = len(all_blocks)
if total_blocks == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = enabled_block_ids
block_ids = list(all_blocks.keys())
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(

View File

@@ -81,7 +81,6 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.name = "Calculator Block"
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
@@ -117,18 +116,11 @@ async def test_block_handler_get_stats(mocker):
"""Test BlockHandler returns correct stats."""
handler = BlockHandler()
# Mock get_blocks - each block class returns an instance with disabled=False
def make_mock_block_class():
mock_class = MagicMock()
mock_instance = MagicMock()
mock_instance.disabled = False
mock_class.return_value = mock_instance
return mock_class
# Mock get_blocks
mock_blocks = {
"block-1": make_mock_block_class(),
"block-2": make_mock_block_class(),
"block-3": make_mock_block_class(),
"block-1": MagicMock(),
"block-2": MagicMock(),
"block-3": MagicMock(),
}
# Mock embedded count query (2 blocks have embeddings)
@@ -317,7 +309,6 @@ async def test_block_handler_handles_missing_attributes():
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
@@ -351,7 +342,6 @@ async def test_block_handler_skips_failed_blocks():
good_instance.name = "Good Block"
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_block.return_value = good_instance
bad_block = MagicMock()

View File

@@ -112,7 +112,6 @@ async def get_store_agents(
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
agent_graph_id=agent.get("agentGraphId", ""),
)
store_agents.append(store_agent)
except Exception as e:
@@ -171,7 +170,6 @@ async def get_store_agents(
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.agentGraphId,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
@@ -1554,7 +1552,7 @@ async def review_store_submission(
# Generate embedding for approved listing (blocking - admin operation)
# Inside transaction: if embedding fails, entire transaction rolls back
await ensure_embedding(
embedding_success = await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
@@ -1562,6 +1560,12 @@ async def review_store_submission(
categories=store_listing_version.categories or [],
tx=tx,
)
if not embedding_success:
raise ValueError(
f"Failed to generate embedding for listing {store_listing_version_id}. "
"This is likely due to OpenAI API being unavailable. "
"Please try again later or contact support if the issue persists."
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": store_listing_version.StoreListing.id},

View File

@@ -21,6 +21,7 @@ from backend.util.json import dumps
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# Embedding dimension for the model above
@@ -62,42 +63,49 @@ def build_searchable_text(
return " ".join(parts)
async def generate_embedding(text: str) -> list[float]:
async def generate_embedding(text: str) -> list[float] | None:
"""
Generate embedding for text using OpenAI API.
Raises exceptions on failure - caller should handle.
Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
"""
client = get_openai_client()
if not client:
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
try:
client = get_openai_client()
if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
else:
truncated_text = text
latency_ms = (time.time() - start_time) * 1000
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def store_embedding(
@@ -136,45 +144,48 @@ async def store_content_embedding(
New function for unified content embedding storage.
Uses raw SQL since Prisma doesn't natively support pgvector.
Raises exceptions on failure - caller should handle.
"""
client = tx if tx else prisma.get_client()
try:
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
# Use unqualified ::vector - pgvector is in search_path on all environments
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
# Use unqualified ::vector - pgvector is in search_path on all environments
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
)
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding(version_id: str) -> dict[str, Any] | None:
@@ -206,31 +217,34 @@ async def get_content_embedding(
New function for unified content embedding retrieval.
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
Raises exceptions on failure - caller should handle.
"""
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
)
try:
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
)
if result and len(result) > 0:
return result[0]
return None
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
return None
async def ensure_embedding(
@@ -258,38 +272,46 @@ async def ensure_embedding(
tx: Optional transaction client
Returns:
True if embedding exists/was created
Raises exceptions on failure - caller should handle.
True if embedding exists/was created, False on failure
"""
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
try:
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Build searchable text for embedding
searchable_text = build_searchable_text(name, description, sub_heading, categories)
# Build searchable text for embedding
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
async def delete_embedding(version_id: str) -> bool:
@@ -454,7 +476,6 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
total_processed = 0
total_success = 0
total_failed = 0
all_errors: dict[str, int] = {} # Aggregate errors across all content types
# Process content types in explicit order
processing_order = [
@@ -500,13 +521,6 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
success = sum(1 for result in results if result is True)
failed = len(results) - success
# Aggregate errors across all content types
if failed > 0:
for result in results:
if isinstance(result, Exception):
error_key = f"{type(result).__name__}: {str(result)}"
all_errors[error_key] = all_errors.get(error_key, 0) + 1
results_by_type[content_type.value] = {
"processed": len(missing_items),
"success": success,
@@ -532,13 +546,6 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
"error": str(e),
}
# Log aggregated errors once at the end
if all_errors:
error_details = ", ".join(
f"{error} ({count}x)" for error, count in all_errors.items()
)
logger.error(f"Embedding backfill errors: {error_details}")
return {
"by_type": results_by_type,
"totals": {
@@ -550,12 +557,11 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
}
async def embed_query(query: str) -> list[float]:
async def embed_query(query: str) -> list[float] | None:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
Raises exceptions on failure - caller should handle.
"""
return await generate_embedding(query)
@@ -588,30 +594,40 @@ async def ensure_content_embedding(
tx: Optional transaction client
Returns:
True if embedding exists/was created
Raises exceptions on failure - caller should handle.
True if embedding exists/was created, False on failure
"""
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for {content_type}:{content_id} already exists")
return True
try:
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(
f"Embedding for {content_type}:{content_id} already exists"
)
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(
f"Could not generate embedding for {content_type}:{content_id}"
)
return False
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
return False
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
@@ -838,8 +854,9 @@ async def semantic_search(
limit = 100
# Generate query embedding
try:
query_embedding = await embed_query(query)
query_embedding = await embed_query(query)
if query_embedding is not None:
# Semantic search with embeddings
embedding_str = embedding_to_vector_string(query_embedding)
@@ -890,21 +907,24 @@ async def semantic_search(
"""
)
results = await query_raw_with_schema(sql, *params)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
try:
results = await query_raw_with_schema(sql, *params)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.error(f"Semantic search failed: {e}")
# Fall through to lexical search below
# Fallback to lexical search if embeddings unavailable
logger.warning("Falling back to lexical search (embeddings unavailable)")
params_lexical: list[Any] = [limit]
user_filter = ""

View File

@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Use a unique search term to avoid matching other test data
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
# Create multiple items
content_ids = []
for i in range(5):
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"{unique_term} item number {i}",
searchable_text=f"pagination test item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
# Get second page
page2_results, total2 = await unified_hybrid_search(
query=unique_term,
query="pagination test",
content_types=[ContentType.BLOCK],
page=2,
page_size=2,

View File

@@ -298,16 +298,17 @@ async def test_schema_handling_error_cases():
mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
# Should raise exception on error
with pytest.raises(Exception, match="Database error"):
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
# Should return False on error, not raise
assert result is False
if __name__ == "__main__":

View File

@@ -80,8 +80,9 @@ async def test_generate_embedding_no_api_key():
) as mock_get_client:
mock_get_client.return_value = None
with pytest.raises(RuntimeError, match="openai_internal_api_key not set"):
await embeddings.generate_embedding("test text")
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@@ -96,8 +97,9 @@ async def test_generate_embedding_api_error():
) as mock_get_client:
mock_get_client.return_value = mock_client
with pytest.raises(Exception, match="API Error"):
await embeddings.generate_embedding("test text")
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@@ -171,10 +173,11 @@ async def test_store_embedding_database_error(mocker):
embedding = [0.1, 0.2, 0.3]
with pytest.raises(Exception, match="Database error"):
await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
@@ -274,16 +277,17 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
"""Test ensure_embedding when generation fails."""
mock_get.return_value = None
mock_generate.side_effect = Exception("Generation failed")
mock_generate.return_value = None
with pytest.raises(Exception, match="Generation failed"):
await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -186,12 +186,13 @@ async def unified_hybrid_search(
offset = (page - 1) * page_size
# Generate query embedding with graceful degradation
try:
query_embedding = await embed_query(query)
except Exception as e:
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation if embedding unavailable
if query_embedding is None or not query_embedding:
logger.warning(
f"Failed to generate query embedding - falling back to lexical-only search: {e}. "
"Failed to generate query embedding - falling back to lexical-only search. "
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
)
query_embedding = [0.0] * EMBEDDING_DIM
@@ -463,12 +464,13 @@ async def hybrid_search(
offset = (page - 1) * page_size
# Generate query embedding with graceful degradation
try:
query_embedding = await embed_query(query)
except Exception as e:
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation
if query_embedding is None or not query_embedding:
logger.warning(
f"Failed to generate query embedding - falling back to lexical-only search: {e}"
"Failed to generate query embedding - falling back to lexical-only search."
)
query_embedding = [0.0] * EMBEDDING_DIM
total_non_semantic = (
@@ -600,7 +602,6 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -660,7 +661,6 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,

View File

@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate embedding failure by raising exception
mock_embed.side_effect = Exception("Embedding generation failed")
# Simulate embedding failure
mock_embed.return_value = None
mock_query.return_value = mock_results
# Should NOT raise - graceful degradation
@@ -613,9 +613,7 @@ async def test_unified_hybrid_search_graceful_degradation():
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.side_effect = Exception(
"Embedding generation failed"
) # Embedding failure
mock_embed.return_value = None # Embedding failure
# Should NOT raise - graceful degradation
results, total = await unified_hybrid_search(

View File

@@ -38,7 +38,6 @@ class StoreAgent(pydantic.BaseModel):
description: str
runs: int
rating: float
agent_graph_id: str
class StoreAgentsResponse(pydantic.BaseModel):

View File

@@ -26,13 +26,11 @@ def test_store_agent():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
@@ -48,7 +46,6 @@ def test_store_agents_response():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(

View File

@@ -82,7 +82,6 @@ def test_get_agents_featured(
description="Featured agent description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-1",
)
],
pagination=store_model.Pagination(
@@ -128,7 +127,6 @@ def test_get_agents_by_creator(
description="Creator agent description",
runs=50,
rating=4.0,
agent_graph_id="test-graph-2",
)
],
pagination=store_model.Pagination(
@@ -174,7 +172,6 @@ def test_get_agents_sorted(
description="Top agent description",
runs=1000,
rating=5.0,
agent_graph_id="test-graph-3",
)
],
pagination=store_model.Pagination(
@@ -220,7 +217,6 @@ def test_get_agents_search(
description="Specific search term description",
runs=75,
rating=4.2,
agent_graph_id="test-graph-search",
)
],
pagination=store_model.Pagination(
@@ -266,7 +262,6 @@ def test_get_agents_category(
description="Category agent description",
runs=60,
rating=4.1,
agent_graph_id="test-graph-category",
)
],
pagination=store_model.Pagination(
@@ -311,7 +306,6 @@ def test_get_agents_pagination(
description=f"Agent {i} description",
runs=i * 10,
rating=4.0,
agent_graph_id="test-graph-2",
)
for i in range(5)
],

View File

@@ -33,7 +33,6 @@ class TestCacheDeletion:
description="Test description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=Pagination(

View File

@@ -261,36 +261,14 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding status check."""
is_onboarding_enabled: bool
is_chat_enabled: bool
@v1_router.get(
"/onboarding/enabled",
summary="Is onboarding enabled",
tags=["onboarding", "public"],
response_model=OnboardingStatusResponse,
dependencies=[Security(requires_user)],
)
async def is_onboarding_enabled(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
return OnboardingStatusResponse(
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
)
async def is_onboarding_enabled() -> bool:
return await onboarding_enabled()
@v1_router.post(
@@ -386,8 +364,6 @@ async def execute_graph_block(
obj = get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled:
raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.")
user = await get_user_by_id(user_id)
if not user:

View File

@@ -138,7 +138,6 @@ def test_execute_graph_block(
"""Test execute block endpoint"""
# Mock block
mock_block = Mock()
mock_block.disabled = False
async def mock_execute(*args, **kwargs):
yield "output1", {"data": "result1"}

View File

@@ -1 +0,0 @@
# Workspace API feature module

View File

@@ -1,122 +0,0 @@
"""
Workspace API routes for managing user file storage.
"""
import logging
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
Removes/replaces characters that could break the header or inject new headers.
Uses RFC5987 encoding for non-ASCII characters.
"""
# Remove CR, LF, and null bytes (header injection prevention)
sanitized = re.sub(r"[\r\n\x00]", "", filename)
# Escape quotes
sanitized = sanitized.replace('"', '\\"')
# For non-ASCII, use RFC5987 filename* parameter
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
dependencies=[fastapi.Security(requires_user)],
)
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
Handles both local storage (direct streaming) and GCS (signed URL redirect
with fallback to streaming).
"""
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> Response:
"""
Download a file by its ID.
Returns the file content directly or redirects to a signed URL for GCS.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)

View File

@@ -32,7 +32,6 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
@@ -40,10 +39,6 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -57,7 +52,6 @@ from backend.util.exceptions import (
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
from backend.util.workspace_storage import shutdown_workspace_storage
from .external.fastapi_app import external_api
from .features.analytics import router as analytics_router
@@ -122,31 +116,14 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")
try:
await shutdown_workspace_storage()
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
await backend.data.db.disconnect()
@@ -338,11 +315,6 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
workspace_routes.router,
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
await asyncio.gather(execution_worker(), notification_worker())
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -13,7 +13,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -118,13 +117,11 @@ class AIImageCustomizerBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("image_url", "https://replicate.delivery/generated-image.jpg"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: MediaFileType(
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
"https://replicate.delivery/generated-image.jpg"
),
},
test_credentials=TEST_CREDENTIALS,
@@ -135,7 +132,8 @@ class AIImageCustomizerBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
@@ -143,9 +141,10 @@ class AIImageCustomizerBlock(Block):
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
user_id=user_id,
return_content=True,
)
for img in input_data.images
)
@@ -159,14 +158,7 @@ class AIImageCustomizerBlock(Block):
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
yield "image_url", result
except Exception as e:
yield "error", str(e)

View File

@@ -6,7 +6,6 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -14,8 +13,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class ImageSize(str, Enum):
@@ -168,13 +165,11 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
# Test output is a data URI since we now store images
lambda x: x.startswith("data:image/"),
"https://replicate.delivery/generated-image.webp",
),
],
test_mock={
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
"_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
},
)
@@ -323,24 +318,11 @@ class AIImageGeneratorBlock(Block):
style_text = style_map.get(style, "")
return f"{style_text} of" if style_text else ""
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
try:
url = await self.generate_image(input_data, credentials)
if url:
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
yield "image_url", url
else:
yield "error", "Image generation returned an empty result."
except Exception as e:

View File

@@ -13,7 +13,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -22,9 +21,7 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -274,10 +271,7 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_output=("video_url", "https://example.com/video.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -286,21 +280,15 @@ class AIShortformVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "data:video/mp4;base64,AAAA",
"videoUrl": "https://example.com/video.mp4",
},
# Use data URI to avoid HTTP requests during tests
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Create a new Webhook.site URL
webhook_token, webhook_url = await self.create_webhook()
@@ -352,13 +340,7 @@ class AIShortformVideoCreatorBlock(Block):
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url
class AIAdMakerVideoCreatorBlock(Block):
@@ -465,10 +447,7 @@ class AIAdMakerVideoCreatorBlock(Block):
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_output=("video_url", "https://example.com/ad.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -477,21 +456,14 @@ class AIAdMakerVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "data:video/mp4;base64,AAAA",
"videoUrl": "https://example.com/ad.mp4",
},
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -559,13 +531,7 @@ class AIAdMakerVideoCreatorBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url
class AIScreenshotToVideoAdBlock(Block):
@@ -660,10 +626,7 @@ class AIScreenshotToVideoAdBlock(Block):
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -672,21 +635,14 @@ class AIScreenshotToVideoAdBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "data:video/mp4;base64,AAAA",
"videoUrl": "https://example.com/screenshot.mp4",
},
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -754,10 +710,4 @@ class AIScreenshotToVideoAdBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url

View File

@@ -6,7 +6,6 @@ if TYPE_CHECKING:
from pydantic import SecretStr
from backend.data.execution import ExecutionContext
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -18,8 +17,6 @@ from backend.sdk import (
Requests,
SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._config import bannerbear
@@ -138,17 +135,15 @@ class BannerbearTextOverlayBlock(Block):
},
test_output=[
("success", True),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
("uid", "test-uid-123"),
("status", "completed"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"_make_api_request": lambda *args, **kwargs: {
"uid": "test-uid-123",
"status": "completed",
"image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
}
},
test_credentials=TEST_CREDENTIALS,
@@ -182,12 +177,7 @@ class BannerbearTextOverlayBlock(Block):
raise Exception(error_msg)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Build the modifications array
modifications = []
@@ -244,18 +234,6 @@ class BannerbearTextOverlayBlock(Block):
# Synchronous request - image should be ready
yield "success", True
# Store the generated image to workspace for persistence
image_url = data.get("image_url", "")
if image_url:
stored_url = await store_media_file(
file=MediaFileType(image_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "image_url", ""
yield "image_url", data.get("image_url", "")
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -9,7 +9,6 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -18,10 +17,10 @@ from backend.util.type import MediaFileType, convert
class FileStoreBlock(Block):
class Input(BlockSchemaInput):
file_in: MediaFileType = SchemaField(
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
@@ -29,18 +28,13 @@ class FileStoreBlock(Block):
class Output(BlockSchemaOutput):
file_out: MediaFileType = SchemaField(
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
description="The relative path to the stored file in the temporary directory."
)
def __init__(self):
super().__init__(
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
description=(
"Downloads and stores a file from a URL, data URI, or local path. "
"Use this to fetch images, documents, or other files for processing. "
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
"In graphs: outputs a data URI to pass to other blocks."
),
description="Stores the input file in the temporary directory.",
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
input_schema=FileStoreBlock.Input,
output_schema=FileStoreBlock.Output,
@@ -51,18 +45,15 @@ class FileStoreBlock(Block):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
execution_context=execution_context,
return_format=return_format,
user_id=user_id,
return_content=input_data.base_64,
)
@@ -125,7 +116,6 @@ class PrintToConsoleBlock(Block):
input_schema=PrintToConsoleBlock.Input,
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -15,7 +15,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -667,7 +666,8 @@ class SendDiscordFileBlock(Block):
file: MediaFileType,
filename: str,
message_content: str,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
) -> dict:
intents = discord.Intents.default()
intents.guilds = True
@@ -731,9 +731,10 @@ class SendDiscordFileBlock(Block):
# Local file path - read from stored media file
# This would be a path from a previous block's output
stored_file = await store_media_file(
graph_exec_id=graph_exec_id,
file=file,
execution_context=execution_context,
return_format="for_external_api", # Get content to send to Discord
user_id=user_id,
return_content=True, # Get as data URI
)
# Now process as data URI
header, encoded = stored_file.split(",", 1)
@@ -780,7 +781,8 @@ class SendDiscordFileBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
@@ -791,7 +793,8 @@ class SendDiscordFileBlock(Block):
file=input_data.file,
filename=input_data.filename,
message_content=input_data.message_content,
execution_context=execution_context,
graph_exec_id=graph_exec_id,
user_id=user_id,
)
yield "status", result.get("status", "Unknown error")

View File

@@ -17,11 +17,8 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
logger = logging.getLogger(__name__)
@@ -67,13 +64,9 @@ class AIVideoGeneratorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
# Output will be a workspace ref or data URI depending on context
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
test_mock={
# Use data URI to avoid HTTP requests during tests
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
},
)
@@ -215,22 +208,11 @@ class AIVideoGeneratorBlock(Block):
raise RuntimeError(f"API request failed: {str(e)}")
async def run(
self,
input_data: Input,
*,
credentials: FalCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: FalCredentials, **kwargs
) -> BlockOutput:
try:
video_url = await self.generate_video(input_data, credentials)
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", video_url
except Exception as e:
error_message = str(e)
yield "error", error_message

View File

@@ -12,7 +12,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -122,12 +121,10 @@ class AIImageEditorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
# Output will be a workspace ref or data URI depending on context
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
("output_image", "https://replicate.com/output/edited-image.png"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -137,7 +134,8 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -146,25 +144,20 @@ class AIImageEditorBlock(Block):
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
user_id=user_id,
return_content=True,
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
user_id=user_id,
graph_exec_id=graph_exec_id,
)
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output_image", stored_url
yield "output_image", result
async def run_model(
self,

View File

@@ -21,7 +21,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -96,7 +95,8 @@ def _make_mime_text(
async def create_mime_message(
input_data,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
@@ -117,12 +117,12 @@ async def create_mime_message(
if input_data.attachments:
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -582,25 +582,27 @@ class GmailSendBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._send_email(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "result", result
async def _send_email(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
if not input_data.to or not input_data.subject or not input_data.body:
raise ValueError(
"At least one recipient, subject, and body are required for sending an email"
)
raw_message = await create_mime_message(input_data, execution_context)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
@@ -690,28 +692,30 @@ class GmailCreateDraftBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._create_draft(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
)
async def _create_draft(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
if not input_data.to or not input_data.subject:
raise ValueError(
"At least one recipient and subject are required for creating a draft"
)
raw_message = await create_mime_message(input_data, execution_context)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1096,7 +1100,7 @@ class GmailGetThreadBlock(GmailBase):
async def _build_reply_message(
service, input_data, execution_context: ExecutionContext
service, input_data, graph_exec_id: str, user_id: str
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
@@ -1186,12 +1190,12 @@ async def _build_reply_message(
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -1307,14 +1311,16 @@ class GmailReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
message = await self._reply(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "messageId", message["id"]
yield "threadId", message.get("threadId", input_data.threadId)
@@ -1337,11 +1343,11 @@ class GmailReplyBlock(GmailBase):
yield "email", email
async def _reply(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, execution_context
service, input_data, graph_exec_id, user_id
)
# Send the message
@@ -1435,14 +1441,16 @@ class GmailDraftReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
@@ -1450,11 +1458,11 @@ class GmailDraftReplyBlock(GmailBase):
yield "status", "draft_created"
async def _create_draft_reply(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, execution_context
service, input_data, graph_exec_id, user_id
)
# Create draft with proper thread association
@@ -1621,21 +1629,23 @@ class GmailForwardBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._forward_message(
service,
input_data,
execution_context,
graph_exec_id,
user_id,
)
yield "messageId", result["id"]
yield "threadId", result.get("threadId", "")
yield "status", "forwarded"
async def _forward_message(
self, service, input_data: Input, execution_context: ExecutionContext
self, service, input_data: Input, graph_exec_id: str, user_id: str
) -> dict:
if not input_data.to:
raise ValueError("At least one recipient is required for forwarding")
@@ -1717,12 +1727,12 @@ To: {original_to}
# Add any additional attachments
for attach in input_data.additionalAttachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())

View File

@@ -9,7 +9,7 @@ from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionStatus
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
@@ -28,11 +28,6 @@ class ReviewDecision(BaseModel):
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def check_approval(**kwargs) -> Optional[ReviewResult]:
"""Check if there's an existing approval for this node execution."""
return await get_database_manager_async_client().check_approval(**kwargs)
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
@@ -60,11 +55,11 @@ class HITLReviewHelper:
async def _handle_review_request(
input_data: Any,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
@@ -74,11 +69,11 @@ class HITLReviewHelper:
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_id: ID of the node in the graph definition
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
@@ -88,41 +83,15 @@ class HITLReviewHelper:
Raises:
Exception: If review creation or status update fails
"""
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
# are handled by the caller:
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
# This function only handles checking for existing approvals.
# Check if this node has already been approved (normal or auto-approval)
if approval_result := await HITLReviewHelper.check_approval(
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
node_id=node_id,
user_id=user_id,
input_data=input_data,
):
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.human_in_the_loop_safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - "
f"found existing approval"
)
# Return a new ReviewResult with the current node_exec_id but approved status
# For auto-approvals, always use current input_data
# For normal approvals, use approval_result.data unless it's None
is_auto_approval = approval_result.node_exec_id != node_exec_id
approved_data = (
input_data
if is_auto_approval
else (
approval_result.data
if approval_result.data is not None
else input_data
)
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=approved_data,
data=input_data,
status=ReviewStatus.APPROVED,
message=approval_result.message,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)
@@ -134,7 +103,7 @@ class HITLReviewHelper:
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=block_name, # Use block_name directly as the message
message=f"Review required for {block_name} execution",
editable=editable,
)
@@ -160,11 +129,11 @@ class HITLReviewHelper:
async def handle_review_decision(
input_data: Any,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
@@ -174,11 +143,11 @@ class HITLReviewHelper:
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_id: ID of the node in the graph definition
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
@@ -189,11 +158,11 @@ class HITLReviewHelper:
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)

View File

@@ -15,7 +15,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -117,9 +116,10 @@ class SendWebRequestBlock(Block):
@staticmethod
async def _prepare_files(
execution_context: ExecutionContext,
graph_exec_id: str,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -127,16 +127,11 @@ class SendWebRequestBlock(Block):
(files_name, (filename, BytesIO, mime_type))
"""
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
graph_exec_id = execution_context.graph_exec_id
if graph_exec_id is None:
raise ValueError("graph_exec_id is required for file operations")
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
file=media,
execution_context=execution_context,
return_format="for_local_processing",
graph_exec_id, media, user_id, return_content=False
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -148,7 +143,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -179,7 +174,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
execution_context, input_data.files_name, input_data.files
graph_exec_id, input_data.files_name, input_data.files, user_id
)
# Enforce body format rules
@@ -243,8 +238,9 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -275,6 +271,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, execution_context=execution_context, **kwargs
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
):
yield output_name, output_data

View File

@@ -97,7 +97,6 @@ class HumanInTheLoopBlock(Block):
input_data: Input,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
@@ -116,12 +115,12 @@ class HumanInTheLoopBlock(Block):
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
block_name=input_data.name, # Use user-provided name instead of block type
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)

View File

@@ -12,7 +12,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -463,21 +462,18 @@ class AgentFileInputBlock(AgentInputBlock):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
execution_context=execution_context,
return_format=return_format,
user_id=user_id,
return_content=input_data.base_64,
)

View File

@@ -162,16 +162,8 @@ class LinearClient:
"searchTerm": team_name,
}
result = await self.query(query, variables)
nodes = result["teams"]["nodes"]
if not nodes:
raise LinearAPIException(
f"Team '{team_name}' not found. Check the team name or key and try again.",
status_code=404,
)
return nodes[0]["id"]
team_id = await self.query(query, variables)
return team_id["teams"]["nodes"][0]["id"]
except LinearAPIException as e:
raise e
@@ -248,44 +240,17 @@ class LinearClient:
except LinearAPIException as e:
raise e
async def try_search_issues(
self,
term: str,
max_results: int = 10,
team_id: str | None = None,
) -> list[Issue]:
async def try_search_issues(self, term: str) -> list[Issue]:
try:
query = """
query SearchIssues(
$term: String!,
$first: Int,
$teamId: String
) {
searchIssues(
term: $term,
first: $first,
teamId: $teamId
) {
query SearchIssues($term: String!, $includeComments: Boolean!) {
searchIssues(term: $term, includeComments: $includeComments) {
nodes {
id
identifier
title
description
priority
createdAt
state {
id
name
type
}
project {
id
name
}
assignee {
id
name
}
}
}
}
@@ -293,8 +258,7 @@ class LinearClient:
variables: dict[str, Any] = {
"term": term,
"first": max_results,
"teamId": team_id,
"includeComments": True,
}
issues = await self.query(query, variables)

View File

@@ -17,7 +17,7 @@ from ._config import (
LinearScope,
linear,
)
from .models import CreateIssueResponse, Issue, State
from .models import CreateIssueResponse, Issue
class LinearCreateIssueBlock(Block):
@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
max_results: int = SchemaField(
description="Maximum number of results to return",
default=10,
ge=1,
le=100,
)
team_name: str | None = SchemaField(
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
default=None,
)
class Output(BlockSchemaOutput):
issues: list[Issue] = SchemaField(description="List of issues")
error: str = SchemaField(description="Error message if the search failed")
def __init__(self):
super().__init__(
@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"term": "Test issue",
"max_results": 10,
"team_name": None,
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,
@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
[
Issue(
id="abc123",
identifier="TST-123",
identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
state=State(
id="state1", name="In Progress", type="started"
),
createdAt="2026-01-15T10:00:00.000Z",
)
],
)
@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
"search_issues": lambda *args, **kwargs: [
Issue(
id="abc123",
identifier="TST-123",
identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
state=State(id="state1", name="In Progress", type="started"),
createdAt="2026-01-15T10:00:00.000Z",
)
]
},
@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
max_results: int = 10,
team_name: str | None = None,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
# Resolve team name to ID if provided
# Raises LinearAPIException with descriptive message if team not found
team_id: str | None = None
if team_name:
team_id = await client.try_get_team_by_name(team_name=team_name)
return await client.try_search_issues(
term=term,
max_results=max_results,
team_id=team_id,
)
response: list[Issue] = await client.try_search_issues(term=term)
return response
async def run(
self,
@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
"""Execute the issue search"""
try:
issues = await self.search_issues(
credentials=credentials,
term=input_data.term,
max_results=input_data.max_results,
team_name=input_data.team_name,
credentials=credentials, term=input_data.term
)
yield "issues", issues
except LinearAPIException as e:

View File

@@ -36,21 +36,12 @@ class Project(BaseModel):
content: str | None = None
class State(BaseModel):
id: str
name: str
type: str | None = (
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
)
class Issue(BaseModel):
id: str
identifier: str
title: str
description: str | None
priority: int
state: State | None = None
project: Project | None = None
createdAt: str | None = None
comments: list[Comment] | None = None

View File

@@ -32,7 +32,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -279,6 +280,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
@@ -634,18 +638,11 @@ async def llm_call(
context_window = llm_model.context_window
if compress_prompt_to_fit:
result = await compress_context(
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
lossy_ok=True,
)
if result.error:
logger.warning(
f"Prompt compression did not meet target: {result.error}. "
f"Proceeding with {result.token_count} tokens."
)
prompt = result.messages
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)

View File

@@ -1,6 +1,6 @@
import os
import tempfile
from typing import Optional
from typing import Literal, Optional
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
@@ -13,7 +13,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
@@ -47,19 +46,18 @@ class MediaDurationBlock(Block):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
user_id=user_id,
return_content=False,
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
# 2) Load the clip
if input_data.is_video:
@@ -90,6 +88,10 @@ class LoopVideoBlock(Block):
default=None,
ge=1,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="How to return the output video. Either a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: str = SchemaField(
@@ -109,19 +111,17 @@ class LoopVideoBlock(Block):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
user_id=user_id,
return_content=False,
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
@@ -149,11 +149,12 @@ class LoopVideoBlock(Block):
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return output - for_block_output returns workspace:// if available, else data URI
# Return as data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
)
yield "video_out", video_out
@@ -176,6 +177,10 @@ class AddAudioToVideoBlock(Block):
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="Return the final output as a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
@@ -195,24 +200,23 @@ class AddAudioToVideoBlock(Block):
self,
input_data: Input,
*,
execution_context: ExecutionContext,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
user_id=user_id,
return_content=False,
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
execution_context=execution_context,
return_format="for_local_processing",
user_id=user_id,
return_content=False,
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
@@ -236,11 +240,12 @@ class AddAudioToVideoBlock(Block):
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return output - for_block_output returns workspace:// if available, else data URI
# 5) Return either path or data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
execution_context=execution_context,
return_format="for_block_output",
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
)
yield "video_out", video_out

View File

@@ -11,7 +11,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -113,7 +112,8 @@ class ScreenshotWebPageBlock(Block):
@staticmethod
async def take_screenshot(
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
url: str,
viewport_width: int,
viewport_height: int,
@@ -155,11 +155,12 @@ class ScreenshotWebPageBlock(Block):
return {
"image": await store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
execution_context=execution_context,
return_format="for_block_output",
user_id=user_id,
return_content=True,
)
}
@@ -168,13 +169,15 @@ class ScreenshotWebPageBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
execution_context=execution_context,
graph_exec_id=graph_exec_id,
user_id=user_id,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -7,7 +7,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
@@ -99,7 +98,7 @@ class ReadSpreadsheetBlock(Block):
)
async def run(
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
@@ -107,16 +106,14 @@ class ReadSpreadsheetBlock(Block):
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
execution_context=execution_context,
return_format="for_local_processing",
return_content=False,
)
# Get full file path
assert execution_context.graph_exec_id # Validated by store_media_file
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@property
def provider_name(self) -> str:
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -182,7 +182,10 @@ class StagehandObserveBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
@@ -227,7 +230,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -279,7 +282,10 @@ class StagehandActBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(
@@ -324,7 +330,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -364,7 +370,10 @@ class StagehandExtractBlock(Block):
**kwargs,
) -> BlockOutput:
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
with disable_signal_handling():
stagehand = Stagehand(

View File

@@ -10,7 +10,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -18,9 +17,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -105,7 +102,7 @@ class CreateTalkingAvatarVideoBlock(Block):
test_output=[
(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
),
],
test_mock={
@@ -113,10 +110,9 @@ class CreateTalkingAvatarVideoBlock(Block):
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
"status": "created",
},
# Use data URI to avoid HTTP requests during tests
"get_clip_status": lambda *args, **kwargs: {
"status": "done",
"result_url": "data:video/mp4;base64,AAAA",
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
},
},
test_credentials=TEST_CREDENTIALS,
@@ -142,12 +138,7 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Create the clip
payload = {
@@ -174,14 +165,7 @@ class CreateTalkingAvatarVideoBlock(Block):
for _ in range(input_data.max_polling_attempts):
status_response = await self.get_clip_status(credentials.api_key, clip_id)
if status_response["status"] == "done":
# Store the generated video to the user's workspace for persistence
video_url = status_response["result_url"]
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
yield "video_url", status_response["result_url"]
return
elif status_response["status"] == "error":
raise RuntimeError(

View File

@@ -12,7 +12,6 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
@@ -234,12 +233,9 @@ class TestStoreMediaFileSecurity:
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
user_id="test_user",
)
@patch("backend.util.file.Path")
@@ -274,12 +270,9 @@ class TestStoreMediaFileSecurity:
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
user_id="test_user",
)

Some files were not shown because too many files have changed in this diff Show More