Compare commits

...

8 Commits

Author SHA1 Message Date
Swifty
bb608ea60d pr comments 2026-01-29 22:29:17 +01:00
Swifty
46af3b94f2 Merge branch 'swiftyos/sse-long-running-tasks' of github.com:Significant-Gravitas/AutoGPT into swiftyos/sse-long-running-tasks 2026-01-29 18:03:01 +01:00
Swifty
083cceca0f fixing edge cases 2026-01-29 18:02:21 +01:00
Swifty
06758adefd Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-01-29 13:33:32 +01:00
Swifty
c01c29a059 fmt issues 2026-01-29 13:28:01 +01:00
Ubbe
b94c83aacc feat(frontend): Copilot speech to text via Whisper model (#11871)
## Changes 🏗️


https://github.com/user-attachments/assets/d9c12ac0-625c-4b38-8834-e494b5eda9c0

Add a "speech to text" feature in the Chat input fox of Copilot, similar
as what you have in ChatGPT.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Run locally and try the speech to text feature as part of the chat
input box

### For configuration changes:

We need to add `OPENAI_API_KEY=` to Vercel ( used in the Front-end )
both in Dev and Prod.

- [x] `.env.default` is updated or already compatible with my changes

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 17:46:36 +07:00
Swifty
d738059da8 added long running task support 2026-01-29 10:24:14 +01:00
Nicholas Tindle
7668c17d9c feat(platform): add User Workspace for persistent CoPilot file storage (#11867)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).

Fixes SECRT-1833

### Changes 🏗️

**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local
filesystem for self-hosted)
- Add `workspace_id` and `session_id` fields to `ExecutionContext`

**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE
/api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()` - returns
`workspace://file-id` references

**Block Updates:**
- Refactor all file-handling blocks to use unified `ExecutionContext`
parameter
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)

**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator

**CoPilot Context Mapping:**
- Session = Agent (graph_id) = Run (graph_exec_id)
- Files scoped to `/sessions/{session_id}/`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
  - [ ] Verify image returns `workspace://file-id` (not base64)
  - [ ] Verify image renders in chat with visibility indicator
  - [ ] Verify workspace files persist across sessions
  - [ ] Test list/read/write workspace files via CoPilot tools
  - [ ] Test local storage backend for self-hosted deployments

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables,
storage backends, download API, and chat tools) and rewires
`store_media_file()`/block execution context across many blocks, so
regressions could impact file handling, access control, or storage
costs.
> 
> **Overview**
> Adds a **persistent per-user Workspace** (new
`UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` +
`WorkspaceStorageBackend` with GCS/local implementations) and wires it
into the API via a new `/api/workspace/files/{file_id}/download` route
(including header-sanitized `Content-Disposition`) and shutdown
lifecycle hooks.
> 
> Extends `ExecutionContext` to carry execution identity +
`workspace_id`/`session_id`, updates executor tooling to clone
node-specific contexts, and updates `run_block` (CoPilot) to create a
session-scoped workspace and synthetic graph/run/node IDs.
> 
> Refactors `store_media_file()` to require `execution_context` +
`return_format` and to support `workspace://` references; migrates many
media/file-handling blocks and related tests to the new API and to
persist generated media as `workspace://...` (or fall back to data URIs
outside CoPilot), and adds CoPilot chat tools for
listing/reading/writing/deleting workspace files with safeguards against
context bloat.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
6abc70f793. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-01-29 05:49:47 +00:00
71 changed files with 5601 additions and 436 deletions

View File

@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
- Format Python code with `poetry run format`.
- Format frontend code using `pnpm format`.
## Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any`, if not types available use `unknown`
## Testing
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
Types:
- feat
- fix
- refactor
- ci
- dx (developer experience)
Scopes:
- platform
- platform/library
- platform/marketplace
- backend
- backend/executor
- frontend
- frontend/library
- frontend/marketplace
- blocks
Types: - feat - fix - refactor - ci - dx (developer experience)
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
## Pull requests

View File

@@ -85,17 +85,6 @@ pnpm format
pnpm types
```
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
### Backend Architecture
@@ -194,6 +183,50 @@ ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
**Handling files in blocks with `store_media_file()`:**
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
@@ -217,14 +250,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Avoid comments at all times unless the code is very complex
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any`, if not types available use `unknown`
### Security Implementation

View File

@@ -0,0 +1,325 @@
"""RabbitMQ consumer for operation completion messages.
This module provides a consumer that listens for completion notifications
from external services (like Agent Generator) and triggers the appropriate
stream registry and chat service updates.
"""
import asyncio
import logging
import orjson
from pydantic import BaseModel
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
)
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Queue and exchange configuration
OPERATION_COMPLETE_EXCHANGE = Exchange(
name="chat_operations",
type=ExchangeType.DIRECT,
durable=True,
)
OPERATION_COMPLETE_QUEUE = Queue(
name="chat_operation_complete",
durable=True,
exchange=OPERATION_COMPLETE_EXCHANGE,
routing_key="operation.complete",
)
RABBITMQ_CONFIG = RabbitMQConfig(
exchanges=[OPERATION_COMPLETE_EXCHANGE],
queues=[OPERATION_COMPLETE_QUEUE],
)
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from RabbitMQ."""
def __init__(self):
self._rabbitmq: AsyncRabbitMQ | None = None
self._consumer_task: asyncio.Task | None = None
self._running = False
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
await self._rabbitmq.connect()
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info("Chat completion consumer started")
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._rabbitmq:
await self._rabbitmq.disconnect()
self._rabbitmq = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
while self._running and retry_count < max_retries:
if not self._rabbitmq:
logger.error("RabbitMQ not initialized")
return
try:
channel = await self._rabbitmq.get_channel()
queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
# Reset retry count on successful connection
retry_count = 0
async with queue.iterator() as queue_iter:
async for message in queue_iter:
if not self._running:
return
try:
async with message.process():
await self._handle_message(message.body)
except Exception as e:
logger.error(
f"Error processing completion message: {e}",
exc_info=True,
)
# Message will be requeued due to exception
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _handle_message(self, body: bytes) -> None:
"""Handle a single completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
# Try to look up by task_id directly
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
# Publish result to stream registry
result_output = message.result if message.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database
result_str = (
message.result
if isinstance(message.result, str)
else (
orjson.dumps(message.result).decode("utf-8")
if message.result
else '{"status": "completed"}'
)
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await chat_service._mark_operation_completed(task.tool_call_id)
logger.info(
f"Successfully processed completion for task {task.task_id} "
f"(operation {message.operation_id})"
)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
error_msg = message.error or "Operation failed"
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
error_response = ErrorResponse(
message=error_msg,
error=message.error,
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await chat_service._mark_operation_completed(task.tool_call_id)
logger.info(
f"Processed failure for task {task.task_id} "
f"(operation {message.operation_id}): {error_msg}"
)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message.
This is a helper function for testing or for services that want to
publish completion messages directly.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
try:
await rabbitmq.connect()
await rabbitmq.publish_message(
routing_key="operation.complete",
message=message.model_dump_json(),
exchange=OPERATION_COMPLETE_EXCHANGE,
)
logger.info(f"Published completion for operation {operation_id}")
finally:
await rabbitmq.disconnect()

View File

@@ -44,6 +44,20 @@ class ChatConfig(BaseSettings):
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=1000,
description="Maximum number of messages to store per stream",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
@@ -82,6 +96,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -4,16 +4,19 @@ import logging
from collections.abc import AsyncGenerator
from typing import Annotated
import orjson
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
config = ChatConfig()
@@ -81,6 +84,14 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -366,6 +377,267 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise NotFoundError(f"Task {task_id} not found or access denied.")
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
chunk_count = 0
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
chunk_count += 1
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"Task stream completed for task {task_id}, "
f"chunk_count={chunk_count}"
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
# Publish result to stream registry
from .response_model import StreamToolOutputAvailable
result_output = request.result if request.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database
from . import service as svc
result_str = (
request.result
if isinstance(request.result, str)
else (
orjson.dumps(request.result).decode("utf-8")
if request.result
else '{"status": "completed"}'
)
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await svc._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await svc._mark_operation_completed(task.tool_call_id)
else:
# Publish error to stream registry
from .response_model import StreamError
error_msg = request.error or "Operation failed"
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Send finish event to end the stream
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
from . import service as svc
from .tools.models import ErrorResponse
error_response = ErrorResponse(
message=error_msg,
error=request.error,
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await svc._mark_operation_completed(task.tool_call_id)
return {"status": "ok", "task_id": task.task_id}
# ========== Health Check ==========

View File

@@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
ChatMessage,
@@ -1610,8 +1611,9 @@ async def _yield_tool_call(
)
return
# Generate operation ID
# Generate operation ID and task ID
operation_id = str(uuid_module.uuid4())
task_id = str(uuid_module.uuid4())
# Build a user-friendly message based on tool and arguments
if tool_name == "create_agent":
@@ -1654,6 +1656,16 @@ async def _yield_tool_call(
# Wrap session save and task creation in try-except to release lock on failure
try:
# Create task in stream registry for SSE reconnection support
await stream_registry.create_task(
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Save assistant message with tool_call FIRST (required by LLM)
assistant_message = ChatMessage(
role="assistant",
@@ -1675,23 +1687,27 @@ async def _yield_tool_call(
session.messages.append(pending_message)
await upsert_chat_session(session)
logger.info(
f"Saved pending operation {operation_id} for tool {tool_name} "
f"in session {session.session_id}"
f"Saved pending operation {operation_id} (task_id={task_id}) "
f"for tool {tool_name} in session {session.session_id}"
)
# Store task reference in module-level set to prevent GC before completion
task = asyncio.create_task(
_execute_long_running_tool(
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=arguments,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
# Associate the asyncio task with the stream registry task
await stream_registry.set_task_asyncio_task(task_id, bg_task)
except Exception as e:
# Roll back appended messages to prevent data corruption on subsequent saves
if (
@@ -1709,6 +1725,11 @@ async def _yield_tool_call(
# Release the Redis lock since the background task won't be spawned
await _mark_operation_completed(tool_call_id)
# Mark stream registry task as failed if it was created
try:
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception:
pass
logger.error(
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
)
@@ -1722,6 +1743,7 @@ async def _yield_tool_call(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id, # Include task_id for SSE reconnection
).model_dump_json(),
success=True,
)
@@ -1791,6 +1813,9 @@ async def _execute_long_running_tool(
This function runs independently of the SSE connection, so the operation
survives if the user closes their browser tab.
NOTE: This is the legacy function without stream registry support.
Use _execute_long_running_tool_with_streaming for new implementations.
"""
try:
# Load fresh session (not stale reference)
@@ -1838,6 +1863,128 @@ async def _execute_long_running_tool(
await _mark_operation_completed(tool_call_id)
async def _execute_long_running_tool_with_streaming(
tool_name: str,
parameters: dict[str, Any],
tool_call_id: str,
operation_id: str,
task_id: str,
session_id: str,
user_id: str | None,
) -> None:
"""Execute a long-running tool with stream registry support for SSE reconnection.
This function runs independently of the SSE connection, publishes progress
to the stream registry, and survives if the user closes their browser tab.
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
If the external service returns a 202 Accepted (async), this function exits
early and lets the RabbitMQ completion consumer handle the rest.
"""
# Track whether we delegated to async processing - if so, the RabbitMQ
# completion consumer will handle cleanup, not us
delegated_to_async = False
try:
# Load fresh session (not stale reference)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for background tool")
await stream_registry.mark_task_completed(task_id, status="failed")
return
# Pass operation_id and task_id to the tool for async processing
enriched_parameters = {
**parameters,
"_operation_id": operation_id,
"_task_id": task_id,
}
# Execute the actual tool
result = await execute_tool(
tool_name=tool_name,
parameters=enriched_parameters,
tool_call_id=tool_call_id,
user_id=user_id,
session=session,
)
# Check if the tool result indicates async processing
# (e.g., Agent Generator returned 202 Accepted)
try:
result_data = orjson.loads(result.output) if result.output else {}
if result_data.get("status") == "accepted":
logger.info(
f"Tool {tool_name} delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id}). "
f"RabbitMQ completion consumer will handle the rest."
)
# Don't publish result, don't continue with LLM, and don't cleanup
# The RabbitMQ consumer will handle everything when the external
# service completes and publishes to the queue
delegated_to_async = True
return
except (orjson.JSONDecodeError, TypeError):
pass # Not JSON or not async - continue normally
# Publish tool result to stream registry
await stream_registry.publish_chunk(task_id, result)
# Update the pending message with result
result_str = (
result.output
if isinstance(result.output, str)
else orjson.dumps(result.output).decode("utf-8")
)
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=result_str,
)
logger.info(
f"Background tool {tool_name} completed for session {session_id} "
f"(task_id={task_id})"
)
# Generate LLM continuation and stream chunks to registry
await _generate_llm_continuation_with_streaming(
session_id=session_id,
user_id=user_id,
task_id=task_id,
)
# Mark task as completed in stream registry
await stream_registry.mark_task_completed(task_id, status="completed")
except Exception as e:
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
error_response = ErrorResponse(
message=f"Tool {tool_name} failed: {str(e)}",
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed in stream registry
await stream_registry.mark_task_completed(task_id, status="failed")
finally:
# Only cleanup if we didn't delegate to async processing
# For async path, the RabbitMQ completion consumer handles cleanup
if not delegated_to_async:
await _mark_operation_completed(tool_call_id)
async def _update_pending_operation(
session_id: str,
tool_call_id: str,
@@ -1964,3 +2111,128 @@ async def _generate_llm_continuation(
except Exception as e:
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
async def _generate_llm_continuation_with_streaming(
session_id: str,
user_id: str | None,
task_id: str,
) -> None:
"""Generate an LLM response with streaming to the stream registry.
This is called by background tasks to continue the conversation
after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them.
"""
import uuid as uuid_module
try:
# Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for LLM continuation")
return
# Build system prompt
system_prompt, _ = await _build_system_prompt(user_id)
# Build messages in OpenAI format
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam
system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages
# Build extra_body for tracing
extra_body: dict[str, Any] = {
"posthogProperties": {
"environment": settings.config.app_env.value,
},
}
if user_id:
extra_body["user"] = user_id[:128]
extra_body["posthogDistinctId"] = user_id
if session_id:
extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
stream = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
stream=True,
)
assistant_content = ""
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
delta = chunk.choices[0].delta.content
assistant_content += delta
# Publish delta to stream registry
await stream_registry.publish_chunk(
task_id,
StreamTextDelta(id=text_block_id, delta=delta),
)
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
if assistant_content:
# Reload session from DB to avoid race condition with user messages
fresh_session = await get_chat_session(session_id, user_id)
if not fresh_session:
logger.error(
f"Session {session_id} disappeared during LLM continuation"
)
return
# Save assistant message to database
assistant_message = ChatMessage(
role="assistant",
content=assistant_content,
)
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
logger.info(
f"Generated streaming LLM continuation for session {session_id} "
f"(task_id={task_id}), response length: {len(assistant_content)}"
)
else:
logger.warning(
f"Streaming LLM continuation returned empty response for {session_id}"
)
except Exception as e:
logger.error(
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -0,0 +1,648 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It supports:
- Creating tasks with unique IDs for long-running operations
- Publishing stream messages to both Redis Streams and in-memory queues
- Subscribing to tasks with replay of missed messages
- Looking up tasks by operation_id for webhook callbacks
- Cross-pod real-time delivery via Redis pub/sub
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track active pub/sub listeners for cross-pod delivery
_pubsub_listeners: dict[str, asyncio.Task] = {}
@dataclass
class ActiveTask:
"""Represents an active streaming task."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
asyncio_task: asyncio.Task | None = None
# Lock for atomic status checks and subscriber management
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
# Set of subscriber queues for fan-out
subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
# Module-level registry for active tasks
_active_tasks: dict[str, ActiveTask] = {}
# Redis key patterns
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{TASK_META_PREFIX}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{TASK_STREAM_PREFIX}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{TASK_OP_PREFIX}{operation_id}"
def _get_task_pubsub_channel(task_id: str) -> str:
"""Get Redis pub/sub channel for task cross-pod delivery."""
return f"{TASK_PUBSUB_PREFIX}{task_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in memory and Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store in memory registry
_active_tasks[task_id] = task
# Store metadata in Redis for durability
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.info(
f"Created streaming task {task_id} for operation {operation_id} "
f"in session {session_id}"
)
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to the task's stream.
Delivers to in-memory subscribers first (for real-time), then persists to
Redis Stream (for replay). This order ensures live subscribers get messages
even if Redis temporarily fails.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
Redis persistence failed
"""
# Deliver to in-memory subscribers FIRST for real-time updates
task = _active_tasks.get(task_id)
if task:
async with task.lock:
for subscriber_queue in task.subscribers:
try:
subscriber_queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id}, dropping chunk"
)
# Then persist to Redis Stream for replay (with error handling)
message_id = "0-0"
chunk_json = chunk.model_dump_json()
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Add to Redis Stream with auto-generated ID
# The ID format is "timestamp-sequence" which gives us ordering
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Publish to pub/sub for cross-pod real-time delivery
pubsub_channel = _get_task_pubsub_channel(task_id)
await redis.publish(pubsub_channel, chunk_json)
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
except Exception as e:
logger.error(
f"Failed to persist chunk to Redis for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
# Check in-memory first
task = _active_tasks.get(task_id)
if task:
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task.user_id}"
)
return None
# Create a new queue for this subscriber
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
# Replay from Redis Stream
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Track the last message ID we've seen for gap detection
replay_last_id = last_message_id
# Read all messages from stream starting after last_message_id
# xread returns messages with ID > last_message_id
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
if messages:
# messages format: [[stream_name, [(id, {data: json}), ...]]]
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
# Track the last message ID we've processed
replay_last_id = (
msg_id if isinstance(msg_id, str) else msg_id.decode()
)
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
# Reconstruct the appropriate response type
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Atomically check status and register subscriber under lock
# This prevents race condition where task completes between check and subscribe
should_start_pubsub = False
async with task.lock:
if task.status == "running":
# Register this subscriber for live updates
task.subscribers.add(subscriber_queue)
# Start pub/sub listener if this is the first subscriber
should_start_pubsub = len(task.subscribers) == 1
logger.debug(
f"Registered subscriber for task {task_id}, "
f"total subscribers: {len(task.subscribers)}"
)
else:
# Task is done, add finish marker
await subscriber_queue.put(StreamFinish())
# After registering, do a second read to catch any messages published
# between the first read and registration (closes the race window)
if task.status == "running":
gap_messages = await redis.xread(
{stream_key: replay_last_id}, block=0, count=1000
)
if gap_messages:
for _stream_name, stream_messages in gap_messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay gap message: {e}")
# Start pub/sub listener outside the lock to avoid deadlocks
if should_start_pubsub:
await start_pubsub_listener(task_id)
return subscriber_queue
# Try to load from Redis if not in memory
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.warning(f"Task {task_id} not found in memory or Redis")
return None
# Validate ownership
task_user_id = meta.get(b"user_id", b"").decode() or None
if user_id and task_user_id and task_user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task_user_id}"
)
return None
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
subscriber_queue = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Read all messages starting after last_message_id
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
if messages:
for _stream_name, stream_messages in messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Add finish marker since task is not active
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> None:
"""Mark a task as completed and publish final event.
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
"""
task = _active_tasks.get(task_id)
if task:
# Acquire lock to prevent new subscribers during completion
async with task.lock:
task.status = status
# Send finish event directly to all current subscribers
finish_event = StreamFinish()
for subscriber_queue in task.subscribers:
try:
subscriber_queue.put_nowait(finish_event)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id} during completion"
)
# Clear subscribers since task is done
task.subscribers.clear()
# Stop pub/sub listener since task is done
await stop_pubsub_listener(task_id)
# Also publish to Redis Stream for replay (and pub/sub for cross-pod)
await publish_chunk(task_id, StreamFinish())
# Remove from active tasks after a short delay to allow subscribers to finish
async def _cleanup():
await asyncio.sleep(5)
_active_tasks.pop(task_id, None)
logger.info(f"Cleaned up task {task_id} from memory")
asyncio.create_task(_cleanup())
# Update Redis metadata
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
await redis.hset(meta_key, "status", status) # type: ignore[misc]
logger.info(f"Marked task {task_id} as {status}")
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
# Check in-memory first
for task in _active_tasks.values():
if task.operation_id == operation_id:
return task
# Try Redis lookup
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if task_id:
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
# Check if task is in memory
if task_id_str in _active_tasks:
return _active_tasks[task_id_str]
# Load metadata from Redis
meta_key = _get_task_meta_key(task_id_str)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
# Reconstruct task object (not fully active, but has metadata)
return ActiveTask(
task_id=meta.get(b"task_id", b"").decode(),
session_id=meta.get(b"session_id", b"").decode(),
user_id=meta.get(b"user_id", b"").decode() or None,
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=operation_id,
status=meta.get(b"status", b"running").decode(), # type: ignore
)
return None
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
# Check in-memory first
if task_id in _active_tasks:
return _active_tasks[task_id]
# Try Redis lookup
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
return ActiveTask(
task_id=meta.get(b"task_id", b"").decode(),
session_id=meta.get(b"session_id", b"").decode(),
user_id=meta.get(b"user_id", b"").decode() or None,
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=meta.get(b"operation_id", b"").decode(),
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
)
return None
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
chunk_type = chunk_data.get("type")
try:
if chunk_type == ResponseType.START.value:
return StreamStart(**chunk_data)
elif chunk_type == ResponseType.FINISH.value:
return StreamFinish(**chunk_data)
elif chunk_type == ResponseType.TEXT_START.value:
return StreamTextStart(**chunk_data)
elif chunk_type == ResponseType.TEXT_DELTA.value:
return StreamTextDelta(**chunk_data)
elif chunk_type == ResponseType.TEXT_END.value:
return StreamTextEnd(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_START.value:
return StreamToolInputStart(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_AVAILABLE.value:
return StreamToolInputAvailable(**chunk_data)
elif chunk_type == ResponseType.TOOL_OUTPUT_AVAILABLE.value:
return StreamToolOutputAvailable(**chunk_data)
elif chunk_type == ResponseType.ERROR.value:
return StreamError(**chunk_data)
elif chunk_type == ResponseType.USAGE.value:
return StreamUsage(**chunk_data)
elif chunk_type == ResponseType.HEARTBEAT.value:
return StreamHeartbeat(**chunk_data)
else:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Associate an asyncio.Task with an ActiveTask.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to associate
"""
task = _active_tasks.get(task_id)
if task:
task.asyncio_task = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Unsubscribe a queue from a task's stream.
Should be called when a client disconnects to clean up resources.
Also stops the pub/sub listener if there are no more local subscribers.
Args:
task_id: Task ID to unsubscribe from
subscriber_queue: The queue to remove from subscribers
"""
task = _active_tasks.get(task_id)
if task:
async with task.lock:
task.subscribers.discard(subscriber_queue)
remaining = len(task.subscribers)
logger.debug(
f"Unsubscribed from task {task_id}, "
f"remaining subscribers: {remaining}"
)
# Stop pub/sub listener if no more local subscribers
if remaining == 0:
await stop_pubsub_listener(task_id)
async def start_pubsub_listener(task_id: str) -> None:
"""Start listening to Redis pub/sub for cross-pod delivery.
This enables real-time updates when another pod publishes chunks for a task
that has local subscribers on this pod.
Args:
task_id: Task ID to listen for
"""
if task_id in _pubsub_listeners:
return # Already listening
task = _active_tasks.get(task_id)
if not task:
return
async def _listener():
try:
redis = await get_redis_async()
pubsub = redis.pubsub()
channel = _get_task_pubsub_channel(task_id)
await pubsub.subscribe(channel)
logger.debug(f"Started pub/sub listener for task {task_id}")
async for message in pubsub.listen():
if message["type"] != "message":
continue
try:
chunk_data = orjson.loads(message["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
# Deliver to local subscribers
local_task = _active_tasks.get(task_id)
if local_task:
async with local_task.lock:
for queue in local_task.subscribers:
try:
queue.put_nowait(chunk)
except asyncio.QueueFull:
pass
# Stop listening if this was a finish event
if isinstance(chunk, StreamFinish):
break
except Exception as e:
logger.warning(f"Error processing pub/sub message: {e}")
await pubsub.unsubscribe(channel)
await pubsub.close()
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Pub/sub listener error for task {task_id}: {e}")
finally:
_pubsub_listeners.pop(task_id, None)
logger.debug(f"Stopped pub/sub listener for task {task_id}")
listener_task = asyncio.create_task(_listener())
_pubsub_listeners[task_id] = listener_task
async def stop_pubsub_listener(task_id: str) -> None:
"""Stop the pub/sub listener for a task.
Args:
task_id: Task ID to stop listening for
"""
listener = _pubsub_listeners.pop(task_id, None)
if listener and not listener.done():
listener.cancel()
try:
await listener
except asyncio.CancelledError:
pass
logger.debug(f"Cancelled pub/sub listener for task {task_id}")

View File

@@ -0,0 +1,79 @@
# CoPilot Tools - Future Ideas
## Multimodal Image Support for CoPilot
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
**Backend Solution:**
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
```python
# Before sending to LLM, scan for workspace image references
# and inject them as image content parts
# Example message transformation:
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
# TO: {"role": "assistant", "content": [
# {"type": "text", "text": "Generated image: workspace://abc123"},
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# ]}
```
**Where to implement:**
- In the chat stream handler before calling the LLM
- Or in a message preprocessing step
- Need to fetch image from workspace, convert to base64, add as image content
**Considerations:**
- Only do this for image MIME types (image/png, image/jpeg, etc.)
- May want a size limit (don't pass 10MB images)
- Track which images were "shown" to the AI for frontend indicator
- Cost implications - vision API calls are more expensive
**Frontend Solution:**
Show visual indicator on workspace files in chat:
- If AI saw the image: normal display
- If AI didn't see it: overlay icon saying "AI can't see this image"
Requires response metadata indicating which `workspace://` refs were passed to the model.
---
## Output Post-Processing Layer for run_block
**Problem:** Many blocks produce large outputs that:
- Consume massive context (100KB base64 image = ~133KB tokens)
- Can't fit in conversation
- Break things and cause high LLM costs
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
**Benefits:**
1. **Centralized** - one place to handle all output processing
2. **Future-proof** - new blocks automatically get output processing
3. **Keeps blocks pure** - they don't need to know about context constraints
4. **Handles all large outputs** - not just images
**Processing Rules:**
- Detect base64 data URIs → save to workspace, return `workspace://` reference
- Truncate very long strings (>N chars) with truncation note
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
- Handle nested large outputs in dicts recursively
- Cap total output size
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
**Example:**
```python
def _process_outputs_for_context(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
max_string_length: int = 10000,
max_array_preview: int = 5,
) -> dict[str, list[Any]]:
"""Process block outputs to prevent context bloat."""
processed = {}
for name, values in outputs.items():
processed[name] = [_process_value(v, workspace_manager) for v in values]
return processed
```

View File

@@ -18,6 +18,12 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WriteWorkspaceFileTool,
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -37,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),
"delete_workspace_file": DeleteWorkspaceFileTool(),
}
# Export individual tool instances for backwards compatibility

View File

@@ -57,21 +57,32 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
return await decompose_goal_external(description, context)
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
async def generate_agent(
instructions: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Agent JSON dict or None on error
Agent JSON dict, {"status": "accepted"} for async, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(instructions)
result = await generate_agent_external(instructions, operation_id, task_id)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
# Ensure required fields
if "id" not in result:
@@ -253,7 +264,10 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
update_request: str,
current_agent: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -265,13 +279,17 @@ async def generate_agent_patch(
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Updated agent JSON, clarifying questions dict, or None on error
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(update_request, current_agent)
return await generate_agent_patch_external(
update_request, current_agent, operation_id, task_id
)

View File

@@ -124,22 +124,39 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any]
instructions: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Agent JSON dict or None on error
Agent JSON dict, or {"status": "accepted"} for async, or None on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
response.raise_for_status()
data = response.json()
@@ -161,27 +178,44 @@ async def generate_agent_external(
async def generate_agent_patch_external(
update_request: str, current_agent: dict[str, Any]
update_request: str,
current_agent: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Updated agent JSON, clarifying questions dict, or None on error
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or None on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
response.raise_for_status()
data = response.json()

View File

@@ -15,6 +15,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -95,6 +96,10 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -173,7 +178,11 @@ class CreateAgentTool(BaseTool):
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(decomposition_result)
agent_json = await generate_agent(
decomposition_result,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -194,6 +203,19 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))

View File

@@ -15,6 +15,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -102,6 +103,10 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -133,7 +138,12 @@ class EditAgentTool(BaseTool):
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(update_request, current_agent)
result = await generate_agent_patch(
update_request,
current_agent,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -152,6 +162,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if LLM returned clarifying questions
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])

View File

@@ -28,6 +28,12 @@ class ResponseType(str, Enum):
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
@@ -346,11 +352,15 @@ class OperationStartedResponse(ToolResponseBase):
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
@@ -374,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase):
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None

View File

@@ -1,6 +1,7 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from collections import defaultdict
from typing import Any
@@ -8,6 +9,7 @@ from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
@@ -223,11 +225,48 @@ class RunBlockTool(BaseTool):
)
try:
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": ExecutionContext(),
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
}
for field_name, cred_meta in matched_credentials.items():

View File

@@ -0,0 +1,620 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -0,0 +1 @@
# Workspace API feature module

View File

@@ -0,0 +1,122 @@
"""
Workspace API routes for managing user file storage.
"""
import logging
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
Removes/replaces characters that could break the header or inject new headers.
Uses RFC5987 encoding for non-ASCII characters.
"""
# Remove CR, LF, and null bytes (header injection prevention)
sanitized = re.sub(r"[\r\n\x00]", "", filename)
# Escape quotes
sanitized = sanitized.replace('"', '\\"')
# For non-ASCII, use RFC5987 filename* parameter
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
dependencies=[fastapi.Security(requires_user)],
)
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
Handles both local storage (direct streaming) and GCS (signed URL redirect
with fallback to streaming).
"""
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> Response:
"""
Download a file by its ID.
Returns the file content directly or redirects to a signed URL for GCS.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)

View File

@@ -22,6 +22,10 @@ import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
import backend.api.features.executions.review.routes
import backend.api.features.library.db
import backend.api.features.library.model
@@ -32,6 +36,7 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
@@ -52,6 +57,7 @@ from backend.util.exceptions import (
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
from backend.util.workspace_storage import shutdown_workspace_storage
from .external.fastapi_app import external_api
from .features.analytics import router as analytics_router
@@ -116,14 +122,31 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for RabbitMQ notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")
try:
await shutdown_workspace_storage()
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
await backend.data.db.disconnect()
@@ -315,6 +338,11 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
workspace_routes.router,
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("image_url", "https://replicate.delivery/generated-image.jpg"),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: MediaFileType(
"https://replicate.delivery/generated-image.jpg"
""
),
},
test_credentials=TEST_CREDENTIALS,
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
for img in input_data.images
)
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
yield "image_url", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
except Exception as e:
yield "error", str(e)

View File

@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -13,6 +14,8 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class ImageSize(str, Enum):
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
"https://replicate.delivery/generated-image.webp",
# Test output is a data URI since we now store images
lambda x: x.startswith(""
},
)
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
style_text = style_map.get(style, "")
return f"{style_text} of" if style_text else ""
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
try:
url = await self.generate_image(input_data, credentials)
if url:
yield "image_url", url
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "error", "Image generation returned an empty result."
except Exception as e:

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -21,7 +22,9 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
# Use data URI to avoid HTTP requests during tests
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create a new Webhook.site URL
webhook_token, webhook_url = await self.create_webhook()
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIAdMakerVideoCreatorBlock(Block):
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=("video_url", "https://example.com/ad.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/ad.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIScreenshotToVideoAdBlock(Block):
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/screenshot.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from pydantic import SecretStr
from backend.data.execution import ExecutionContext
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -17,6 +18,8 @@ from backend.sdk import (
Requests,
SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._config import bannerbear
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
},
test_output=[
("success", True),
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("uid", "test-uid-123"),
("status", "completed"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"_make_api_request": lambda *args, **kwargs: {
"uid": "test-uid-123",
"status": "completed",
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
"image_url": "",
}
},
test_credentials=TEST_CREDENTIALS,
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
raise Exception(error_msg)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Build the modifications array
modifications = []
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
# Synchronous request - image should be ready
yield "success", True
yield "image_url", data.get("image_url", "")
# Store the generated image to workspace for persistence
image_url = data.get("image_url", "")
if image_url:
stored_url = await store_media_file(
file=MediaFileType(image_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "image_url", ""
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -9,6 +9,7 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
class FileStoreBlock(Block):
class Input(BlockSchemaInput):
file_in: MediaFileType = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
class Output(BlockSchemaOutput):
file_out: MediaFileType = SchemaField(
description="The relative path to the stored file in the temporary directory."
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
)
def __init__(self):
super().__init__(
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
description="Stores the input file in the temporary directory.",
description=(
"Downloads and stores a file from a URL, data URI, or local path. "
"Use this to fetch images, documents, or other files for processing. "
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
"In graphs: outputs a data URI to pass to other blocks."
),
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
input_schema=FileStoreBlock.Input,
output_schema=FileStoreBlock.Output,
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
file: MediaFileType,
filename: str,
message_content: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> dict:
intents = discord.Intents.default()
intents.guilds = True
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
# Local file path - read from stored media file
# This would be a path from a previous block's output
stored_file = await store_media_file(
graph_exec_id=graph_exec_id,
file=file,
user_id=user_id,
return_content=True, # Get as data URI
execution_context=execution_context,
return_format="for_external_api", # Get content to send to Discord
)
# Now process as data URI
header, encoded = stored_file.split(",", 1)
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
file=input_data.file,
filename=input_data.filename,
message_content=input_data.message_content,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
)
yield "status", result.get("status", "Unknown error")

View File

@@ -17,8 +17,11 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
logger = logging.getLogger(__name__)
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
test_output=[
# Output will be a workspace ref or data URI depending on context
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
# Use data URI to avoid HTTP requests during tests
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
},
)
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
raise RuntimeError(f"API request failed: {str(e)}")
async def run(
self, input_data: Input, *, credentials: FalCredentials, **kwargs
self,
input_data: Input,
*,
credentials: FalCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
video_url = await self.generate_video(input_data, credentials)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
except Exception as e:
error_message = str(e)
yield "error", error_message

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("output_image", "https://replicate.com/output/edited-image.png"),
# Output will be a workspace ref or data URI depending on context
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: "",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
)
yield "output_image", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output_image", stored_url
async def run_model(
self,

View File

@@ -21,6 +21,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -95,8 +96,7 @@ def _make_mime_text(
async def create_mime_message(
input_data,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
@@ -117,12 +117,12 @@ async def create_mime_message(
if input_data.attachments:
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._send_email(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", result
async def _send_email(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject or not input_data.body:
raise ValueError(
"At least one recipient, subject, and body are required for sending an email"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._create_draft(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
)
async def _create_draft(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject:
raise ValueError(
"At least one recipient and subject are required for creating a draft"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
async def _build_reply_message(
service, input_data, graph_exec_id: str, user_id: str
service, input_data, execution_context: ExecutionContext
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
message = await self._reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", message["id"]
yield "threadId", message.get("threadId", input_data.threadId)
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
yield "email", email
async def _reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Send the message
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
yield "status", "draft_created"
async def _create_draft_reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Create draft with proper thread association
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._forward_message(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", result["id"]
yield "threadId", result.get("threadId", "")
yield "status", "forwarded"
async def _forward_message(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to:
raise ValueError("At least one recipient is required for forwarding")
@@ -1727,12 +1717,12 @@ To: {original_to}
# Add any additional attachments
for attach in input_data.additionalAttachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
@staticmethod
async def _prepare_files(
graph_exec_id: str,
execution_context: ExecutionContext,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
(files_name, (filename, BytesIO, mime_type))
"""
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
graph_exec_id = execution_context.graph_exec_id
if graph_exec_id is None:
raise ValueError("graph_exec_id is required for file operations")
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
graph_exec_id, media, user_id, return_content=False
file=media,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
graph_exec_id, input_data.files_name, input_data.files, user_id
execution_context, input_data.files_name, input_data.files
)
# Enforce body format rules
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
execution_context: ExecutionContext,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
base_input, execution_context=execution_context, **kwargs
):
yield output_name, output_data

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -1,6 +1,6 @@
import os
import tempfile
from typing import Literal, Optional
from typing import Optional
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
# 2) Load the clip
if input_data.is_video:
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
default=None,
ge=1,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="How to return the output video. Either a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: str = SchemaField(
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return as data URI
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="Return the final output as a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return either path or data URI
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
@staticmethod
async def take_screenshot(
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
url: str,
viewport_width: int,
viewport_height: int,
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
return {
"image": await store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_block_output",
)
}
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -7,6 +7,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
assert execution_context.graph_exec_id # Validated by store_media_file
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -10,6 +10,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -17,7 +18,9 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
test_output=[
(
"video_url",
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
lambda x: x.startswith(("workspace://", "data:")),
),
],
test_mock={
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
"status": "created",
},
# Use data URI to avoid HTTP requests during tests
"get_clip_status": lambda *args, **kwargs: {
"status": "done",
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
"result_url": "data:video/mp4;base64,AAAA",
},
},
test_credentials=TEST_CREDENTIALS,
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create the clip
payload = {
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
for _ in range(input_data.max_polling_attempts):
status_response = await self.get_clip_status(credentials.api_key, clip_id)
if status_response["status"] == "done":
yield "video_url", status_response["result_url"]
# Store the generated video to the user's workspace for persistence
video_url = status_response["result_url"]
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
return
elif status_response["status"] == "error":
raise RuntimeError(

View File

@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
)
@patch("backend.util.file.Path")
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
)

View File

@@ -11,10 +11,22 @@ from backend.blocks.http import (
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.execution import ExecutionContext
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
def make_test_context(
graph_exec_id: str = "test-exec-id",
user_id: str = "test-user-id",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util import json, text
from backend.util.file import get_exec_file_path, store_media_file
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
# Store the media file properly (handles URLs, data URIs, etc.)
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
# Get full file path (graph_exec_id validated by store_media_file above)
if not execution_context.graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
model_config = {"extra": "ignore"}
# Execution identity
user_id: Optional[str] = None
graph_id: Optional[str] = None
graph_exec_id: Optional[str] = None
graph_version: Optional[int] = None
node_id: Optional[str] = None
node_exec_id: Optional[str] = None
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
# User settings
user_timezone: str = "UTC"
# Execution hierarchy
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# Workspace
workspace_id: Optional[str] = None
session_id: Optional[str] = None
# -------------------------- Models -------------------------- #

View File

@@ -0,0 +1,276 @@
"""
Database CRUD operations for User Workspace.
This module provides functions for managing user workspaces and workspace files.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.models import UserWorkspace, UserWorkspaceFile
from prisma.types import UserWorkspaceFileWhereInput
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
"""
Get user's workspace, creating one if it doesn't exist.
Uses upsert to handle race conditions when multiple concurrent requests
attempt to create a workspace for the same user.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance
"""
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No updates needed if exists
},
)
return workspace
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
"""
Get user's workspace if it exists.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
async def create_workspace_file(
workspace_id: str,
file_id: str,
name: str,
path: str,
storage_path: str,
mime_type: str,
size_bytes: int,
checksum: Optional[str] = None,
metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
"""
Create a new workspace file record.
Args:
workspace_id: The workspace ID
file_id: The file ID (same as used in storage path for consistency)
name: User-visible filename
path: Virtual path (e.g., "/documents/report.pdf")
storage_path: Actual storage path (GCS or local)
mime_type: MIME type of the file
size_bytes: File size in bytes
checksum: Optional SHA256 checksum
metadata: Optional additional metadata
Returns:
Created UserWorkspaceFile instance
"""
# Normalize path to start with /
if not path.startswith("/"):
path = f"/{path}"
file = await UserWorkspaceFile.prisma().create(
data={
"id": file_id,
"workspaceId": workspace_id,
"name": name,
"path": path,
"storagePath": storage_path,
"mimeType": mime_type,
"sizeBytes": size_bytes,
"checksum": checksum,
"metadata": SafeJson(metadata or {}),
}
)
logger.info(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
)
return file
async def get_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by ID.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
UserWorkspaceFile instance or None
"""
where_clause: dict = {"id": file_id, "isDeleted": False}
if workspace_id:
where_clause["workspaceId"] = workspace_id
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
async def get_workspace_file_by_path(
workspace_id: str,
path: str,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by its virtual path.
Args:
workspace_id: The workspace ID
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
# Normalize path
if not path.startswith("/"):
path = f"/{path}"
return await UserWorkspaceFile.prisma().find_first(
where={
"workspaceId": workspace_id,
"path": path,
"isDeleted": False,
}
)
async def list_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
limit: Optional[int] = None,
offset: int = 0,
) -> list[UserWorkspaceFile]:
"""
List files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/documents/")
include_deleted: Whether to include soft-deleted files
limit: Maximum number of files to return
offset: Number of files to skip
Returns:
List of UserWorkspaceFile instances
"""
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().find_many(
where=where_clause,
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def count_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
) -> int:
"""
Count files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
include_deleted: Whether to include soft-deleted files
Returns:
Number of files
"""
where_clause: dict = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().count(where=where_clause)
async def soft_delete_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Soft-delete a workspace file.
The path is modified to include a deletion timestamp to free up the original
path for new files while preserving the record for potential recovery.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
Updated UserWorkspaceFile instance or None if not found
"""
# First verify the file exists and belongs to workspace
file = await get_workspace_file(file_id, workspace_id)
if file is None:
return None
deleted_at = datetime.now(timezone.utc)
# Modify path to free up the unique constraint for new files at original path
# Format: {original_path}__deleted__{timestamp}
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
updated = await UserWorkspaceFile.prisma().update(
where={"id": file_id},
data={
"isDeleted": True,
"deletedAt": deleted_at,
"path": deleted_path,
},
)
logger.info(f"Soft-deleted workspace file {file_id}")
return updated
async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.sizeBytes for file in files)

View File

@@ -236,7 +236,14 @@ async def execute_node(
input_size = len(input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
# Create node-specific execution context to avoid race conditions
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
execution_context = execution_context.model_copy(
update={"node_id": node_id, "node_exec_id": node_exec_id}
)
# Inject extra execution arguments for the blocks via kwargs
# Keep individual kwargs for backwards compatibility with existing blocks
extra_exec_kwargs: dict = {
"graph_id": graph_id,
"graph_version": graph_version,

View File

@@ -892,11 +892,19 @@ async def add_graph_execution(
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec.id,
graph_version=graph_exec.graph_version,
# Safety settings
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
# User settings
user_timezone=(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
),
# Execution hierarchy
root_execution_id=graph_exec.id,
)

View File

@@ -348,6 +348,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
mock_graph_exec.graph_version = graph_version
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
# Mock the queue and event bus
@@ -434,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Create a second mock execution for the sanity check
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec_2.id = "execution-id-456"
mock_graph_exec_2.node_executions = []
mock_graph_exec_2.status = ExecutionStatus.QUEUED
mock_graph_exec_2.graph_version = graph_version
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
# Reset mocks and set up for second call
@@ -614,6 +618,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
mock_graph_exec.graph_version = graph_version
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}

View File

@@ -13,6 +13,7 @@ import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -251,7 +252,7 @@ class CloudStorageHandler:
f"in_task: {current_task is not None}"
)
# Parse bucket and blob name from path
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path: {path}")
@@ -261,50 +262,19 @@ class CloudStorageHandler:
# Authorization check
self._validate_file_access(blob_name, user_id, graph_exec_id)
# Use a fresh client for each download to avoid session issues
# This is less efficient but more reliable with the executor's event loop
logger.info("[CloudStorage] Creating fresh GCS client for download")
# Create a new session specifically for this download
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10, force_close=True)
logger.info(
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
)
async_client = None
try:
# Create a new GCS client with the fresh session
async_client = async_gcs_storage.Storage(session=session)
logger.info(
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
)
# Download content using the fresh client
content = await async_client.download(bucket_name, blob_name)
content = await download_with_fresh_session(bucket_name, blob_name)
logger.info(
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
)
# Clean up
await async_client.close()
await session.close()
return content
except FileNotFoundError:
raise
except Exception as e:
# Always try to clean up
if async_client is not None:
try:
await async_client.close()
except Exception as cleanup_error:
logger.warning(
f"[CloudStorage] Error closing GCS client: {cleanup_error}"
)
try:
await session.close()
except Exception as cleanup_error:
logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
# Log the specific error for debugging
logger.error(
f"[CloudStorage] GCS download failed - error: {str(e)}, "
@@ -319,10 +289,6 @@ class CloudStorageHandler:
f"current_task: {current_task}, "
f"bucket: {bucket_name}, blob: redacted for privacy"
)
# Convert gcloud-aio exceptions to standard ones
if "404" in str(e) or "Not Found" in str(e):
raise FileNotFoundError(f"File not found: gcs://{path}")
raise
def _validate_file_access(
@@ -445,8 +411,7 @@ class CloudStorageHandler:
graph_exec_id: str | None = None,
) -> str:
"""Generate signed URL for GCS with authorization."""
# Parse bucket and blob name from path
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path: {path}")
@@ -456,21 +421,11 @@ class CloudStorageHandler:
# Authorization check
self._validate_file_access(blob_name, user_id, graph_exec_id)
# Use sync client for signed URLs since gcloud-aio doesn't support them
sync_client = self._get_sync_gcs_client()
bucket = sync_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
# Generate signed URL asynchronously using sync client
url = await asyncio.to_thread(
blob.generate_signed_url,
version="v4",
expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
method="GET",
return await generate_signed_url(
sync_client, bucket_name, blob_name, expiration_hours * 3600
)
return url
async def delete_expired_files(self, provider: str = "gcs") -> int:
"""
Delete files that have passed their expiration time.

View File

@@ -5,13 +5,26 @@ import shutil
import tempfile
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from urllib.parse import urlparse
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.request import Requests
from backend.util.settings import Config
from backend.util.type import MediaFileType
from backend.util.virus_scanner import scan_content_safe
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
# Return format options for store_media_file
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
MediaReturnFormat = Literal[
"for_local_processing", "for_external_api", "for_block_output"
]
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
# Maximum filename length (conservative limit for most filesystems)
@@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
async def store_media_file(
graph_exec_id: str,
file: MediaFileType,
user_id: str,
return_content: bool = False,
execution_context: "ExecutionContext",
*,
return_format: MediaReturnFormat,
) -> MediaFileType:
"""
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
placing or verifying it under:
Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
{tempdir}/exec_file/{exec_id}/...
If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
Otherwise, returns the file media path relative to the exec_id folder.
For each MediaFileType input:
- Data URI: decode and store locally
- URL: download and store locally
- workspace:// reference: read from workspace, store locally
- Local path: verify it exists in exec_file directory
For each MediaFileType type:
- Data URI:
-> decode and store in a new random file in that folder
- URL:
-> download and store in that folder
- Local path:
-> interpret as relative to that folder; verify it exists
(no copying, as it's presumably already there).
We realpath-check so no symlink or '..' can escape the folder.
Return format options:
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
:param graph_exec_id: The unique ID of the graph execution.
:param file: Data URI, URL, or local (relative) path.
:param return_content: If True, return a data URI of the file content.
If False, return the *relative* path inside the exec_id folder.
:return: The requested result: data URI or relative path of the media.
:param file: Data URI, URL, workspace://, or local (relative) path.
:param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
:param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
:return: The requested result based on return_format.
"""
# Extract values from execution_context
graph_exec_id = execution_context.graph_exec_id
user_id = execution_context.user_id
if not graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
if not user_id:
raise ValueError("execution_context.user_id is required")
# Create workspace_manager if we have workspace_id (with session scoping)
# Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
from backend.util.workspace import WorkspaceManager
workspace_manager: WorkspaceManager | None = None
if execution_context.workspace_id:
workspace_manager = WorkspaceManager(
user_id, execution_context.workspace_id, execution_context.session_id
)
# Build base path
base_path = Path(get_exec_file_path(graph_exec_id, ""))
base_path.mkdir(parents=True, exist_ok=True)
# Security fix: Add disk space limits to prevent DoS
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file
MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
# Check total disk usage in base_path
@@ -142,9 +169,57 @@ async def store_media_file(
"""
return str(absolute_path.relative_to(base))
# Check if this is a cloud storage path
# Get cloud storage handler for checking cloud paths
cloud_storage = await get_cloud_storage_handler()
if cloud_storage.is_cloud_path(file):
# Track if the input came from workspace (don't re-save it)
is_from_workspace = file.startswith("workspace://")
# Check if this is a workspace file reference
if is_from_workspace:
if workspace_manager is None:
raise ValueError(
"Workspace file reference requires workspace context. "
"This file type is only available in CoPilot sessions."
)
# Parse workspace reference
# workspace://abc123 - by file ID
# workspace:///path/to/file.txt - by virtual path
file_ref = file[12:] # Remove "workspace://"
if file_ref.startswith("/"):
# Path reference
workspace_content = await workspace_manager.read_file(file_ref)
file_info = await workspace_manager.get_file_info_by_path(file_ref)
filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin"
)
else:
# ID reference
workspace_content = await workspace_manager.read_file_by_id(file_ref)
file_info = await workspace_manager.get_file_info(file_ref)
filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin"
)
try:
target_path = _ensure_inside_base(base_path / filename, base_path)
except OSError as e:
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Check file size limit
if len(workspace_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Virus scan the workspace content before writing locally
await scan_content_safe(workspace_content, filename=filename)
target_path.write_bytes(workspace_content)
# Check if this is a cloud storage path
elif cloud_storage.is_cloud_path(file):
# Download from cloud storage and store locally
cloud_content = await cloud_storage.retrieve_file(
file, user_id=user_id, graph_exec_id=graph_exec_id
@@ -159,9 +234,9 @@ async def store_media_file(
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Check file size limit
if len(cloud_content) > MAX_FILE_SIZE:
if len(cloud_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Virus scan the cloud content before writing locally
@@ -189,9 +264,9 @@ async def store_media_file(
content = base64.b64decode(b64_content)
# Check file size limit
if len(content) > MAX_FILE_SIZE:
if len(content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Virus scan the base64 content before writing
@@ -199,23 +274,31 @@ async def store_media_file(
target_path.write_bytes(content)
elif file.startswith(("http://", "https://")):
# URL
# URL - download first to get Content-Type header
resp = await Requests().get(file)
# Check file size limit
if len(resp.content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Extract filename from URL path
parsed_url = urlparse(file)
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
# If filename lacks extension, add one from Content-Type header
if "." not in filename:
content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
if content_type:
ext = _extension_from_mime(content_type)
filename = f"{filename}{ext}"
try:
target_path = _ensure_inside_base(base_path / filename, base_path)
except OSError as e:
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Download and save
resp = await Requests().get(file)
# Check file size limit
if len(resp.content) > MAX_FILE_SIZE:
raise ValueError(
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
)
# Virus scan the downloaded content before writing
await scan_content_safe(resp.content, filename=filename)
target_path.write_bytes(resp.content)
@@ -230,12 +313,44 @@ async def store_media_file(
if not target_path.is_file():
raise ValueError(f"Local file does not exist: {target_path}")
# Return result
if return_content:
return MediaFileType(_file_to_data_uri(target_path))
else:
# Return based on requested format
if return_format == "for_local_processing":
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
# Returns: relative path in exec_file directory (e.g., "image.png")
return MediaFileType(_strip_base_prefix(target_path, base_path))
elif return_format == "for_external_api":
# Use when sending content to external APIs that need base64
# Returns: data URI (e.g., "...")
return MediaFileType(_file_to_data_uri(target_path))
elif return_format == "for_block_output":
# Use when returning output from a block to user/next block
# Returns: workspace:// ref (CoPilot) or data URI (graph execution)
if workspace_manager is None:
# No workspace available (graph execution without CoPilot)
# Fallback to data URI so the content can still be used/displayed
return MediaFileType(_file_to_data_uri(target_path))
# Don't re-save if input was already from workspace
if is_from_workspace:
# Return original workspace reference
return MediaFileType(file)
# Save new content to workspace
content = target_path.read_bytes()
filename = target_path.name
file_record = await workspace_manager.write_file(
content=content,
filename=filename,
overwrite=True,
)
return MediaFileType(f"workspace://{file_record.id}")
else:
raise ValueError(f"Invalid return_format: {return_format}")
def get_dir_size(path: Path) -> int:
"""Get total size of directory."""

View File

@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
def make_test_context(
graph_exec_id: str = "test-exec-123",
user_id: str = "test-user-123",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestFileCloudIntegration:
"""Test cases for cloud storage integration in file utilities."""
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
mock_path_class.side_effect = path_constructor
result = await store_media_file(
graph_exec_id,
MediaFileType(cloud_path),
"test-user-123",
return_content=False,
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify cloud storage operations
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
mock_path_obj.name = "image.png"
with patch("backend.util.file.Path", return_value=mock_path_obj):
result = await store_media_file(
graph_exec_id,
MediaFileType(cloud_path),
"test-user-123",
return_content=True,
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_external_api",
)
# Verify result is a data URI
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
await store_media_file(
graph_exec_id,
MediaFileType(data_uri),
"test-user-123",
return_content=False,
file=MediaFileType(data_uri),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify cloud handler was checked but not used for retrieval
@@ -234,5 +243,7 @@ class TestFileCloudIntegration:
FileNotFoundError, match="File not found in cloud storage"
):
await store_media_file(
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)

View File

@@ -0,0 +1,108 @@
"""
Shared GCS utilities for workspace and cloud storage backends.
This module provides common functionality for working with Google Cloud Storage,
including path parsing, client management, and signed URL generation.
"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
logger = logging.getLogger(__name__)
def parse_gcs_path(path: str) -> tuple[str, str]:
"""
Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).
Args:
path: GCS path string (e.g., "gcs://my-bucket/path/to/file")
Returns:
Tuple of (bucket_name, blob_name)
Raises:
ValueError: If the path format is invalid
"""
if not path.startswith("gcs://"):
raise ValueError(f"Invalid GCS path: {path}")
path_without_prefix = path[6:] # Remove "gcs://"
parts = path_without_prefix.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path format: {path}")
return parts[0], parts[1]
async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
"""
Download file content using a fresh session.
This approach avoids event loop issues that can occur when reusing
sessions across different async contexts (e.g., in executors).
Args:
bucket: GCS bucket name
blob: Blob path within the bucket
Returns:
File content as bytes
Raises:
FileNotFoundError: If the file doesn't exist
"""
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10, force_close=True)
)
client: async_gcs_storage.Storage | None = None
try:
client = async_gcs_storage.Storage(session=session)
content = await client.download(bucket, blob)
return content
except Exception as e:
if "404" in str(e) or "Not Found" in str(e):
raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
raise
finally:
if client:
try:
await client.close()
except Exception:
pass # Best-effort cleanup
await session.close()
async def generate_signed_url(
sync_client: gcs_storage.Client,
bucket_name: str,
blob_name: str,
expires_in: int,
) -> str:
"""
Generate a signed URL for temporary access to a GCS file.
Uses asyncio.to_thread() to run the sync operation without blocking.
Args:
sync_client: Sync GCS client with service account credentials
bucket_name: GCS bucket name
blob_name: Blob path within the bucket
expires_in: URL expiration time in seconds
Returns:
Signed URL string
"""
bucket = sync_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
return await asyncio.to_thread(
blob.generate_signed_url,
version="v4",
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
method="GET",
)

View File

@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The name of the Google Cloud Storage bucket for media files",
)
workspace_storage_dir: str = Field(
default="",
description="Local directory for workspace file storage when GCS is not configured. "
"If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
)
reddit_user_agent: str = Field(
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
description="The user agent for the Reddit API",
@@ -389,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum file size in MB for file uploads (1-1024 MB)",
)
max_file_size_mb: int = Field(
default=100,
ge=1,
le=1024,
description="Maximum file size in MB for workspace files (1-1024 MB)",
)
# AutoMod configuration
automod_enabled: bool = Field(
default=False,

View File

@@ -140,14 +140,29 @@ async def execute_block_test(block: Block):
setattr(block, mock_name, mock_obj)
# Populate credentials argument(s)
# Generate IDs for execution context
graph_id = str(uuid.uuid4())
node_id = str(uuid.uuid4())
graph_exec_id = str(uuid.uuid4())
node_exec_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
graph_version = 1 # Default version for tests
extra_exec_kwargs: dict = {
"graph_id": str(uuid.uuid4()),
"node_id": str(uuid.uuid4()),
"graph_exec_id": str(uuid.uuid4()),
"node_exec_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"graph_version": 1, # Default version for tests
"execution_context": ExecutionContext(),
"graph_id": graph_id,
"node_id": node_id,
"graph_exec_id": graph_exec_id,
"node_exec_id": node_exec_id,
"user_id": user_id,
"graph_version": graph_version,
"execution_context": ExecutionContext(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
graph_version=graph_version,
node_id=node_id,
node_exec_id=node_exec_id,
),
}
input_model = cast(type[BlockSchema], block.input_schema)

View File

@@ -0,0 +1,419 @@
"""
WorkspaceManager for managing user workspace file operations.
This module provides a high-level interface for workspace file operations,
combining the storage backend and database layer.
"""
import logging
import mimetypes
import uuid
from typing import Optional
from prisma.errors import UniqueViolationError
from prisma.models import UserWorkspaceFile
from backend.data.workspace import (
count_workspace_files,
create_workspace_file,
get_workspace_file,
get_workspace_file_by_path,
list_workspace_files,
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
logger = logging.getLogger(__name__)
class WorkspaceManager:
"""
Manages workspace file operations.
Combines storage backend operations with database record management.
Supports session-scoped file segmentation where files are stored in
session-specific virtual paths: /sessions/{session_id}/{filename}
"""
def __init__(
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
):
"""
Initialize WorkspaceManager.
Args:
user_id: The user's ID
workspace_id: The workspace ID
session_id: Optional session ID for session-scoped file access
"""
self.user_id = user_id
self.workspace_id = workspace_id
self.session_id = session_id
# Session path prefix for file isolation
self.session_path = f"/sessions/{session_id}" if session_id else ""
def _resolve_path(self, path: str) -> str:
"""
Resolve a path, defaulting to session folder if session_id is set.
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
Args:
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
Returns:
Resolved path with session prefix if applicable
"""
# If path explicitly references a session folder, use it as-is
if path.startswith("/sessions/"):
return path
# If we have a session context, prepend session path
if self.session_path:
# Normalize the path
if not path.startswith("/"):
path = f"/{path}"
return f"{self.session_path}{path}"
# No session context, use path as-is
return path if path.startswith("/") else f"/{path}"
def _get_effective_path(
self, path: Optional[str], include_all_sessions: bool
) -> Optional[str]:
"""
Get effective path for list/count operations based on session context.
Args:
path: Optional path prefix to filter
include_all_sessions: If True, don't apply session scoping
Returns:
Effective path prefix for database query
"""
if include_all_sessions:
# Normalize path to ensure leading slash (stored paths are normalized)
if path is not None and not path.startswith("/"):
return f"/{path}"
return path
elif path is not None:
# Resolve the provided path with session scoping
return self._resolve_path(path)
elif self.session_path:
# Default to session folder with trailing slash to prevent prefix collisions
# e.g., "/sessions/abc" should not match "/sessions/abc123"
return self.session_path.rstrip("/") + "/"
else:
# No session context, use path as-is
return path
async def read_file(self, path: str) -> bytes:
"""
Read file from workspace by virtual path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path (e.g., "/documents/report.pdf")
Returns:
File content as bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
resolved_path = self._resolve_path(path)
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
if file is None:
raise FileNotFoundError(f"File not found at path: {resolved_path}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
async def read_file_by_id(self, file_id: str) -> bytes:
"""
Read file from workspace by file ID.
Args:
file_id: The file's ID
Returns:
File content as bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
async def write_file(
self,
content: bytes,
filename: str,
path: Optional[str] = None,
mime_type: Optional[str] = None,
overwrite: bool = False,
) -> UserWorkspaceFile:
"""
Write file to workspace.
When session_id is set, files are written to /sessions/{session_id}/...
by default. Use explicit /sessions/... paths for cross-session access.
Args:
content: File content as bytes
filename: Filename for the file
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
mime_type: MIME type (auto-detected if not provided)
overwrite: Whether to overwrite existing file at path
Returns:
Created UserWorkspaceFile instance
Raises:
ValueError: If file exceeds size limit or path already exists
"""
# Enforce file size limit
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
raise ValueError(
f"File too large: {len(content)} bytes exceeds "
f"{Config().max_file_size_mb}MB limit"
)
# Determine path with session scoping
if path is None:
path = f"/{filename}"
elif not path.startswith("/"):
path = f"/{path}"
# Resolve path with session prefix
path = self._resolve_path(path)
# Check if file exists at path (only error for non-overwrite case)
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
# This ensures the new file is written to storage BEFORE the old one is deleted,
# preventing data loss if the new write fails
if not overwrite:
existing = await get_workspace_file_by_path(self.workspace_id, path)
if existing is not None:
raise ValueError(f"File already exists at path: {path}")
# Auto-detect MIME type if not provided
if mime_type is None:
mime_type, _ = mimetypes.guess_type(filename)
mime_type = mime_type or "application/octet-stream"
# Compute checksum
checksum = compute_file_checksum(content)
# Generate unique file ID for storage
file_id = str(uuid.uuid4())
# Store file in storage backend
storage = await get_workspace_storage()
storage_path = await storage.store(
workspace_id=self.workspace_id,
file_id=file_id,
filename=filename,
content=content,
)
# Create database record - handle race condition where another request
# created a file at the same path between our check and create
try:
file = await create_workspace_file(
workspace_id=self.workspace_id,
file_id=file_id,
name=filename,
path=path,
storage_path=storage_path,
mime_type=mime_type,
size_bytes=len(content),
checksum=checksum,
)
except UniqueViolationError:
# Race condition: another request created a file at this path
if overwrite:
# Re-fetch and delete the conflicting file, then retry
existing = await get_workspace_file_by_path(self.workspace_id, path)
if existing:
await self.delete_file(existing.id)
# Retry the create - if this also fails, clean up storage file
try:
file = await create_workspace_file(
workspace_id=self.workspace_id,
file_id=file_id,
name=filename,
path=path,
storage_path=storage_path,
mime_type=mime_type,
size_bytes=len(content),
checksum=checksum,
)
except Exception:
# Clean up orphaned storage file on retry failure
try:
await storage.delete(storage_path)
except Exception as e:
logger.warning(f"Failed to clean up orphaned storage file: {e}")
raise
else:
# Clean up the orphaned storage file before raising
try:
await storage.delete(storage_path)
except Exception as e:
logger.warning(f"Failed to clean up orphaned storage file: {e}")
raise ValueError(f"File already exists at path: {path}")
except Exception:
# Any other database error (connection, validation, etc.) - clean up storage
try:
await storage.delete(storage_path)
except Exception as e:
logger.warning(f"Failed to clean up orphaned storage file: {e}")
raise
logger.info(
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
f"at path {path}, size={len(content)} bytes"
)
return file
async def list_files(
self,
path: Optional[str] = None,
limit: Optional[int] = None,
offset: int = 0,
include_all_sessions: bool = False,
) -> list[UserWorkspaceFile]:
"""
List files in workspace.
When session_id is set and include_all_sessions is False (default),
only files in the current session's folder are listed.
Args:
path: Optional path prefix to filter (e.g., "/documents/")
limit: Maximum number of files to return
offset: Number of files to skip
include_all_sessions: If True, list files from all sessions.
If False (default), only list current session's files.
Returns:
List of UserWorkspaceFile instances
"""
effective_path = self._get_effective_path(path, include_all_sessions)
return await list_workspace_files(
workspace_id=self.workspace_id,
path_prefix=effective_path,
limit=limit,
offset=offset,
)
async def delete_file(self, file_id: str) -> bool:
"""
Delete a file (soft-delete).
Args:
file_id: The file's ID
Returns:
True if deleted, False if not found
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
return False
# Delete from storage
storage = await get_workspace_storage()
try:
await storage.delete(file.storagePath)
except Exception as e:
logger.warning(f"Failed to delete file from storage: {e}")
# Continue with database soft-delete even if storage delete fails
# Soft-delete database record
result = await soft_delete_workspace_file(file_id, self.workspace_id)
return result is not None
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
"""
Get download URL for a file.
Args:
file_id: The file's ID
expires_in: URL expiration in seconds (default 1 hour)
Returns:
Download URL (signed URL for GCS, API endpoint for local)
Raises:
FileNotFoundError: If file doesn't exist
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.get_download_url(file.storagePath, expires_in)
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
"""
Get file metadata.
Args:
file_id: The file's ID
Returns:
UserWorkspaceFile instance or None
"""
return await get_workspace_file(file_id, self.workspace_id)
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
"""
Get file metadata by path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
resolved_path = self._resolve_path(path)
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
async def get_file_count(
self,
path: Optional[str] = None,
include_all_sessions: bool = False,
) -> int:
"""
Get number of files in workspace.
When session_id is set and include_all_sessions is False (default),
only counts files in the current session's folder.
Args:
path: Optional path prefix to filter (e.g., "/documents/")
include_all_sessions: If True, count all files in workspace.
If False (default), only count current session's files.
Returns:
Number of files
"""
effective_path = self._get_effective_path(path, include_all_sessions)
return await count_workspace_files(
self.workspace_id, path_prefix=effective_path
)

View File

@@ -0,0 +1,398 @@
"""
Workspace storage backend abstraction for supporting both cloud and local deployments.
This module provides a unified interface for storing workspace files, with implementations
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
"""
import asyncio
import hashlib
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import aiofiles
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
from backend.util.data import get_data_path
from backend.util.gcs_utils import (
download_with_fresh_session,
generate_signed_url,
parse_gcs_path,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
class WorkspaceStorageBackend(ABC):
"""Abstract interface for workspace file storage."""
@abstractmethod
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""
Store file content, return storage path.
Args:
workspace_id: The workspace ID
file_id: Unique file ID for storage
filename: Original filename
content: File content as bytes
Returns:
Storage path string (cloud path or local path)
"""
pass
@abstractmethod
async def retrieve(self, storage_path: str) -> bytes:
"""
Retrieve file content from storage.
Args:
storage_path: The storage path returned from store()
Returns:
File content as bytes
"""
pass
@abstractmethod
async def delete(self, storage_path: str) -> None:
"""
Delete file from storage.
Args:
storage_path: The storage path to delete
"""
pass
@abstractmethod
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Get URL for downloading the file.
Args:
storage_path: The storage path
expires_in: URL expiration time in seconds (default 1 hour)
Returns:
Download URL (signed URL for GCS, direct API path for local)
"""
pass
class GCSWorkspaceStorage(WorkspaceStorageBackend):
"""Google Cloud Storage implementation for workspace storage."""
def __init__(self, bucket_name: str):
self.bucket_name = bucket_name
self._async_client: Optional[async_gcs_storage.Storage] = None
self._sync_client: Optional[gcs_storage.Client] = None
self._session: Optional[aiohttp.ClientSession] = None
async def _get_async_client(self) -> async_gcs_storage.Storage:
"""Get or create async GCS client."""
if self._async_client is None:
self._session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=100, force_close=False)
)
self._async_client = async_gcs_storage.Storage(session=self._session)
return self._async_client
def _get_sync_client(self) -> gcs_storage.Client:
"""Get or create sync GCS client (for signed URLs)."""
if self._sync_client is None:
self._sync_client = gcs_storage.Client()
return self._sync_client
async def close(self) -> None:
"""Close all client connections."""
if self._async_client is not None:
try:
await self._async_client.close()
except Exception as e:
logger.warning(f"Error closing GCS client: {e}")
self._async_client = None
if self._session is not None:
try:
await self._session.close()
except Exception as e:
logger.warning(f"Error closing session: {e}")
self._session = None
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
"""Build the blob path for workspace files."""
return f"workspaces/{workspace_id}/{file_id}/{filename}"
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""Store file in GCS."""
client = await self._get_async_client()
blob_name = self._build_blob_name(workspace_id, file_id, filename)
# Upload with metadata
upload_time = datetime.now(timezone.utc)
await client.upload(
self.bucket_name,
blob_name,
content,
metadata={
"uploaded_at": upload_time.isoformat(),
"workspace_id": workspace_id,
"file_id": file_id,
},
)
return f"gcs://{self.bucket_name}/{blob_name}"
async def retrieve(self, storage_path: str) -> bytes:
"""Retrieve file from GCS."""
bucket_name, blob_name = parse_gcs_path(storage_path)
return await download_with_fresh_session(bucket_name, blob_name)
async def delete(self, storage_path: str) -> None:
"""Delete file from GCS."""
bucket_name, blob_name = parse_gcs_path(storage_path)
client = await self._get_async_client()
try:
await client.delete(bucket_name, blob_name)
except Exception as e:
if "404" not in str(e) and "Not Found" not in str(e):
raise
# File already deleted, that's fine
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Generate download URL for GCS file.
Attempts to generate a signed URL if running with service account credentials.
Falls back to an API proxy endpoint if signed URL generation fails
(e.g., when running locally with user OAuth credentials).
"""
bucket_name, blob_name = parse_gcs_path(storage_path)
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
blob_parts = blob_name.split("/")
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
# Try to generate signed URL (requires service account credentials)
try:
sync_client = self._get_sync_client()
return await generate_signed_url(
sync_client, bucket_name, blob_name, expires_in
)
except AttributeError as e:
# Signed URL generation requires service account with private key.
# When running with user OAuth credentials, fall back to API proxy.
if "private key" in str(e) and file_id:
logger.debug(
"Cannot generate signed URL (no service account credentials), "
"falling back to API proxy endpoint"
)
return f"/api/workspace/files/{file_id}/download"
raise
class LocalWorkspaceStorage(WorkspaceStorageBackend):
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
def __init__(self, base_dir: Optional[str] = None):
"""
Initialize local storage backend.
Args:
base_dir: Base directory for workspace storage.
If None, defaults to {app_data}/workspaces
"""
if base_dir:
self.base_dir = Path(base_dir)
else:
self.base_dir = Path(get_data_path()) / "workspaces"
# Ensure base directory exists
self.base_dir.mkdir(parents=True, exist_ok=True)
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
"""Build the local file path with path traversal protection."""
# Import here to avoid circular import
# (file.py imports workspace.py which imports workspace_storage.py)
from backend.util.file import sanitize_filename
# Sanitize filename to prevent path traversal (removes / and \ among others)
safe_filename = sanitize_filename(filename)
file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()
# Verify the resolved path is still under base_dir
if not file_path.is_relative_to(self.base_dir.resolve()):
raise ValueError("Invalid filename: path traversal detected")
return file_path
def _parse_storage_path(self, storage_path: str) -> Path:
"""Parse local storage path to filesystem path."""
if storage_path.startswith("local://"):
relative_path = storage_path[8:] # Remove "local://"
else:
relative_path = storage_path
full_path = (self.base_dir / relative_path).resolve()
# Security check: ensure path is under base_dir
# Use is_relative_to() for robust path containment check
# (handles case-insensitive filesystems and edge cases)
if not full_path.is_relative_to(self.base_dir.resolve()):
raise ValueError("Invalid storage path: path traversal detected")
return full_path
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""Store file locally."""
file_path = self._build_file_path(workspace_id, file_id, filename)
# Create parent directories
file_path.parent.mkdir(parents=True, exist_ok=True)
# Write file asynchronously
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
# Return relative path as storage path
relative_path = file_path.relative_to(self.base_dir)
return f"local://{relative_path}"
async def retrieve(self, storage_path: str) -> bytes:
"""Retrieve file from local storage."""
file_path = self._parse_storage_path(storage_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {storage_path}")
async with aiofiles.open(file_path, "rb") as f:
return await f.read()
async def delete(self, storage_path: str) -> None:
"""Delete file from local storage."""
file_path = self._parse_storage_path(storage_path)
if file_path.exists():
# Remove file
file_path.unlink()
# Clean up empty parent directories
parent = file_path.parent
while parent != self.base_dir:
try:
if parent.exists() and not any(parent.iterdir()):
parent.rmdir()
else:
break
except OSError:
break
parent = parent.parent
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Get download URL for local file.
For local storage, this returns an API endpoint path.
The actual serving is handled by the API layer.
"""
# Parse the storage path to get the components
if storage_path.startswith("local://"):
relative_path = storage_path[8:]
else:
relative_path = storage_path
# Return the API endpoint for downloading
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
parts = relative_path.split("/")
if len(parts) >= 2:
file_id = parts[1] # Second component is file_id
return f"/api/workspace/files/{file_id}/download"
else:
raise ValueError(f"Invalid storage path format: {storage_path}")
# Global storage backend instance
_workspace_storage: Optional[WorkspaceStorageBackend] = None
_storage_lock = asyncio.Lock()
async def get_workspace_storage() -> WorkspaceStorageBackend:
"""
Get the workspace storage backend instance.
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
"""
global _workspace_storage
if _workspace_storage is None:
async with _storage_lock:
if _workspace_storage is None:
config = Config()
if config.media_gcs_bucket_name:
logger.info(
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
)
_workspace_storage = GCSWorkspaceStorage(
config.media_gcs_bucket_name
)
else:
storage_dir = (
config.workspace_storage_dir
if config.workspace_storage_dir
else None
)
logger.info(
f"Using local workspace storage: {storage_dir or 'default'}"
)
_workspace_storage = LocalWorkspaceStorage(storage_dir)
return _workspace_storage
async def shutdown_workspace_storage() -> None:
"""
Properly shutdown the global workspace storage backend.
Closes aiohttp sessions and other resources for GCS backend.
Should be called during application shutdown.
"""
global _workspace_storage
if _workspace_storage is not None:
async with _storage_lock:
if _workspace_storage is not None:
if isinstance(_workspace_storage, GCSWorkspaceStorage):
await _workspace_storage.close()
_workspace_storage = None
def compute_file_checksum(content: bytes) -> str:
"""Compute SHA256 checksum of file content."""
return hashlib.sha256(content).hexdigest()

View File

@@ -0,0 +1,52 @@
-- CreateEnum
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
-- CreateTable
CREATE TABLE "UserWorkspace" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "UserWorkspaceFile" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"workspaceId" TEXT NOT NULL,
"name" TEXT NOT NULL,
"path" TEXT NOT NULL,
"storagePath" TEXT NOT NULL,
"mimeType" TEXT NOT NULL,
"sizeBytes" BIGINT NOT NULL,
"checksum" TEXT,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
"deletedAt" TIMESTAMP(3),
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
"sourceExecId" TEXT,
"sourceSessionId" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
-- CreateIndex
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
-- CreateIndex
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
-- AddForeignKey
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,16 @@
/*
Warnings:
- You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost.
- You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
- You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source",
DROP COLUMN "sourceExecId",
DROP COLUMN "sourceSessionId";
-- DropEnum
DROP TYPE "WorkspaceFileSource";

View File

@@ -63,6 +63,7 @@ model User {
IntegrationWebhooks IntegrationWebhook[]
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -137,6 +138,53 @@ model CoPilotUnderstanding {
@@index([userId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER WORKSPACE TABLES /////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// User's persistent file storage workspace
model UserWorkspace {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
Files UserWorkspaceFile[]
@@index([userId])
}
// Individual files in a user's workspace
model UserWorkspaceFile {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
workspaceId String
Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
// File metadata
name String // User-visible filename
path String // Virtual path (e.g., "/documents/report.pdf")
storagePath String // Actual GCS or local storage path
mimeType String
sizeBytes BigInt
checksum String? // SHA256 for integrity
// File state
isDeleted Boolean @default(false)
deletedAt DateTime?
metadata Json @default("{}")
@@unique([workspaceId, path])
@@index([workspaceId, isDeleted])
}
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())

View File

@@ -34,3 +34,6 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV=
# PostHog Analytics
NEXT_PUBLIC_POSTHOG_KEY=
NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
# OpenAI (for voice transcription)
OPENAI_API_KEY=

View File

@@ -73,9 +73,9 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
};
const reset = () => {
// Only reset the offset - keep existing sessions visible during refetch
// The effect will replace sessions when new data arrives at offset 0
setOffset(0);
setAccumulatedSessions([]);
setTotalCount(null);
};
return {

View File

@@ -5912,6 +5912,40 @@
}
}
},
"/api/workspace/files/{file_id}/download": {
"get": {
"tags": ["workspace"],
"summary": "Download file by ID",
"description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
"operationId": "getWorkspaceDownload file by id",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "file_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "File Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/health": {
"get": {
"tags": ["health"],

View File

@@ -1,5 +1,6 @@
import {
ApiError,
getServerAuthToken,
makeAuthenticatedFileUpload,
makeAuthenticatedRequest,
} from "@/lib/autogpt-server-api/helpers";
@@ -15,6 +16,69 @@ function buildBackendUrl(path: string[], queryString: string): string {
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
}
/**
* Check if this is a workspace file download request that needs binary response handling.
*/
function isWorkspaceDownloadRequest(path: string[]): boolean {
// Match pattern: api/workspace/files/{id}/download (5 segments)
return (
path.length == 5 &&
path[0] === "api" &&
path[1] === "workspace" &&
path[2] === "files" &&
path[path.length - 1] === "download"
);
}
/**
* Handle workspace file download requests with proper binary response streaming.
*/
async function handleWorkspaceDownload(
req: NextRequest,
backendUrl: string,
): Promise<NextResponse> {
const token = await getServerAuthToken();
const headers: Record<string, string> = {};
if (token && token !== "no-token-found") {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(backendUrl, {
method: "GET",
headers,
redirect: "follow", // Follow redirects to signed URLs
});
if (!response.ok) {
return NextResponse.json(
{ error: `Failed to download file: ${response.statusText}` },
{ status: response.status },
);
}
// Get the content type from the backend response
const contentType =
response.headers.get("Content-Type") || "application/octet-stream";
const contentDisposition = response.headers.get("Content-Disposition");
// Stream the response body
const responseHeaders: Record<string, string> = {
"Content-Type": contentType,
};
if (contentDisposition) {
responseHeaders["Content-Disposition"] = contentDisposition;
}
// Return the binary content
const arrayBuffer = await response.arrayBuffer();
return new NextResponse(arrayBuffer, {
status: 200,
headers: responseHeaders,
});
}
async function handleJsonRequest(
req: NextRequest,
method: string,
@@ -180,6 +244,11 @@ async function handler(
};
try {
// Handle workspace file downloads separately (binary response)
if (method === "GET" && isWorkspaceDownloadRequest(path)) {
return await handleWorkspaceDownload(req, backendUrl);
}
if (method === "GET" || method === "DELETE") {
responseBody = await handleGetDeleteRequest(method, backendUrl, req);
} else if (contentType?.includes("application/json")) {

View File

@@ -0,0 +1,77 @@
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest, NextResponse } from "next/server";
const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions";
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit
function getExtensionFromMimeType(mimeType: string): string {
const subtype = mimeType.split("/")[1]?.split(";")[0];
return subtype || "webm";
}
export async function POST(request: NextRequest) {
const token = await getServerAuthToken();
if (!token || token === "no-token-found") {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}
const apiKey = process.env.OPENAI_API_KEY;
if (!apiKey) {
return NextResponse.json(
{ error: "OpenAI API key not configured" },
{ status: 401 },
);
}
try {
const formData = await request.formData();
const audioFile = formData.get("audio");
if (!audioFile || !(audioFile instanceof Blob)) {
return NextResponse.json(
{ error: "No audio file provided" },
{ status: 400 },
);
}
if (audioFile.size > MAX_FILE_SIZE) {
return NextResponse.json(
{ error: "File too large. Maximum size is 25MB." },
{ status: 413 },
);
}
const ext = getExtensionFromMimeType(audioFile.type);
const whisperFormData = new FormData();
whisperFormData.append("file", audioFile, `recording.${ext}`);
whisperFormData.append("model", "whisper-1");
const response = await fetch(WHISPER_API_URL, {
method: "POST",
headers: {
Authorization: `Bearer ${apiKey}`,
},
body: whisperFormData,
});
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
console.error("Whisper API error:", errorData);
return NextResponse.json(
{ error: errorData.error?.message || "Transcription failed" },
{ status: response.status },
);
}
const result = await response.json();
return NextResponse.json({ text: result.text });
} catch (error) {
console.error("Transcription error:", error);
return NextResponse.json(
{ error: "Failed to process audio" },
{ status: 500 },
);
}
}

View File

@@ -1,7 +1,14 @@
import { Button } from "@/components/atoms/Button/Button";
import { cn } from "@/lib/utils";
import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react";
import {
ArrowUpIcon,
CircleNotchIcon,
MicrophoneIcon,
StopIcon,
} from "@phosphor-icons/react";
import { RecordingIndicator } from "./components/RecordingIndicator";
import { useChatInput } from "./useChatInput";
import { useVoiceRecording } from "./useVoiceRecording";
export interface Props {
onSend: (message: string) => void;
@@ -21,13 +28,36 @@ export function ChatInput({
className,
}: Props) {
const inputId = "chat-input";
const { value, handleKeyDown, handleSubmit, handleChange, hasMultipleLines } =
useChatInput({
onSend,
disabled: disabled || isStreaming,
maxRows: 4,
inputId,
});
const {
value,
setValue,
handleKeyDown: baseHandleKeyDown,
handleSubmit,
handleChange,
hasMultipleLines,
} = useChatInput({
onSend,
disabled: disabled || isStreaming,
maxRows: 4,
inputId,
});
const {
isRecording,
isTranscribing,
elapsedTime,
toggleRecording,
handleKeyDown,
showMicButton,
isInputDisabled,
audioStream,
} = useVoiceRecording({
setValue,
disabled: disabled || isStreaming,
isStreaming,
value,
baseHandleKeyDown,
});
return (
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
@@ -35,8 +65,11 @@ export function ChatInput({
<div
id={`${inputId}-wrapper`}
className={cn(
"relative overflow-hidden border border-neutral-200 bg-white shadow-sm",
"focus-within:border-zinc-400 focus-within:ring-1 focus-within:ring-zinc-400",
"relative overflow-hidden border bg-white shadow-sm",
"focus-within:ring-1",
isRecording
? "border-red-400 focus-within:border-red-400 focus-within:ring-red-400"
: "border-neutral-200 focus-within:border-zinc-400 focus-within:ring-zinc-400",
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
)}
>
@@ -46,48 +79,94 @@ export function ChatInput({
value={value}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={placeholder}
disabled={disabled || isStreaming}
placeholder={
isTranscribing
? "Transcribing..."
: isRecording
? ""
: placeholder
}
disabled={isInputDisabled}
rows={1}
className={cn(
"w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
"placeholder:text-zinc-400",
"focus:outline-none focus:ring-0",
"disabled:text-zinc-500",
hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4",
hasMultipleLines
? "pb-6 pl-4 pr-4 pt-2"
: showMicButton
? "pb-4 pl-14 pr-14 pt-4"
: "pb-4 pl-4 pr-14 pt-4",
)}
/>
{isRecording && !value && (
<div className="pointer-events-none absolute inset-0 flex items-center justify-center">
<RecordingIndicator
elapsedTime={elapsedTime}
audioStream={audioStream}
/>
</div>
)}
</div>
<span id="chat-input-hint" className="sr-only">
Press Enter to send, Shift+Enter for new line
Press Enter to send, Shift+Enter for new line, Space to record voice
</span>
{isStreaming ? (
<Button
type="button"
variant="icon"
size="icon"
aria-label="Stop generating"
onClick={onStop}
className="absolute bottom-[7px] right-2 border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
>
<StopIcon className="h-4 w-4" weight="bold" />
</Button>
) : (
<Button
type="submit"
variant="icon"
size="icon"
aria-label="Send message"
className={cn(
"absolute bottom-[7px] right-2 border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
(disabled || !value.trim()) && "opacity-20",
)}
disabled={disabled || !value.trim()}
>
<ArrowUpIcon className="h-4 w-4" weight="bold" />
</Button>
{showMicButton && (
<div className="absolute bottom-[7px] left-2 flex items-center gap-1">
<Button
type="button"
variant="icon"
size="icon"
aria-label={isRecording ? "Stop recording" : "Start recording"}
onClick={toggleRecording}
disabled={disabled || isTranscribing}
className={cn(
isRecording
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
: isTranscribing
? "border-zinc-300 bg-zinc-100 text-zinc-400"
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
)}
>
{isTranscribing ? (
<CircleNotchIcon className="h-4 w-4 animate-spin" />
) : (
<MicrophoneIcon className="h-4 w-4" weight="bold" />
)}
</Button>
</div>
)}
<div className="absolute bottom-[7px] right-2 flex items-center gap-1">
{isStreaming ? (
<Button
type="button"
variant="icon"
size="icon"
aria-label="Stop generating"
onClick={onStop}
className="border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
>
<StopIcon className="h-4 w-4" weight="bold" />
</Button>
) : (
<Button
type="submit"
variant="icon"
size="icon"
aria-label="Send message"
className={cn(
"border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
(disabled || !value.trim() || isRecording) && "opacity-20",
)}
disabled={disabled || !value.trim() || isRecording}
>
<ArrowUpIcon className="h-4 w-4" weight="bold" />
</Button>
)}
</div>
</div>
</form>
);

View File

@@ -0,0 +1,142 @@
"use client";
import { useEffect, useRef, useState } from "react";
interface Props {
stream: MediaStream | null;
barCount?: number;
barWidth?: number;
barGap?: number;
barColor?: string;
minBarHeight?: number;
maxBarHeight?: number;
}
export function AudioWaveform({
stream,
barCount = 24,
barWidth = 3,
barGap = 2,
barColor = "#ef4444", // red-500
minBarHeight = 4,
maxBarHeight = 32,
}: Props) {
const [bars, setBars] = useState<number[]>(() =>
Array(barCount).fill(minBarHeight),
);
const analyserRef = useRef<AnalyserNode | null>(null);
const audioContextRef = useRef<AudioContext | null>(null);
const sourceRef = useRef<MediaStreamAudioSourceNode | null>(null);
const animationRef = useRef<number | null>(null);
useEffect(() => {
if (!stream) {
setBars(Array(barCount).fill(minBarHeight));
return;
}
// Create audio context and analyser
const audioContext = new AudioContext();
const analyser = audioContext.createAnalyser();
analyser.fftSize = 512;
analyser.smoothingTimeConstant = 0.8;
// Connect the stream to the analyser
const source = audioContext.createMediaStreamSource(stream);
source.connect(analyser);
audioContextRef.current = audioContext;
analyserRef.current = analyser;
sourceRef.current = source;
const timeData = new Uint8Array(analyser.frequencyBinCount);
const updateBars = () => {
if (!analyserRef.current) return;
analyserRef.current.getByteTimeDomainData(timeData);
// Distribute time-domain data across bars
// This shows waveform amplitude, making all bars respond to audio
const newBars: number[] = [];
const samplesPerBar = timeData.length / barCount;
for (let i = 0; i < barCount; i++) {
// Sample waveform data for this bar
let maxAmplitude = 0;
const startIdx = Math.floor(i * samplesPerBar);
const endIdx = Math.floor((i + 1) * samplesPerBar);
for (let j = startIdx; j < endIdx && j < timeData.length; j++) {
// Convert to amplitude (distance from center 128)
const amplitude = Math.abs(timeData[j] - 128);
maxAmplitude = Math.max(maxAmplitude, amplitude);
}
// Map amplitude (0-128) to bar height
const normalized = (maxAmplitude / 128) * 255;
const height =
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
newBars.push(height);
}
setBars(newBars);
animationRef.current = requestAnimationFrame(updateBars);
};
updateBars();
return () => {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
}
if (sourceRef.current) {
sourceRef.current.disconnect();
}
if (audioContextRef.current) {
audioContextRef.current.close();
}
analyserRef.current = null;
audioContextRef.current = null;
sourceRef.current = null;
};
}, [stream, barCount, minBarHeight, maxBarHeight]);
const totalWidth = barCount * barWidth + (barCount - 1) * barGap;
return (
<div
className="flex items-center justify-center"
style={{
width: totalWidth,
height: maxBarHeight,
gap: barGap,
}}
>
{bars.map((height, i) => {
const barHeight = Math.max(minBarHeight, height);
return (
<div
key={i}
className="relative"
style={{
width: barWidth,
height: maxBarHeight,
}}
>
<div
className="absolute left-0 rounded-full transition-[height] duration-75"
style={{
width: barWidth,
height: barHeight,
top: "50%",
transform: "translateY(-50%)",
backgroundColor: barColor,
}}
/>
</div>
);
})}
</div>
);
}

View File

@@ -0,0 +1,26 @@
import { formatElapsedTime } from "../helpers";
import { AudioWaveform } from "./AudioWaveform";
type Props = {
elapsedTime: number;
audioStream: MediaStream | null;
};
export function RecordingIndicator({ elapsedTime, audioStream }: Props) {
return (
<div className="flex items-center gap-3">
<AudioWaveform
stream={audioStream}
barCount={20}
barWidth={3}
barGap={2}
barColor="#ef4444"
minBarHeight={4}
maxBarHeight={24}
/>
<span className="min-w-[3ch] text-sm font-medium text-red-500">
{formatElapsedTime(elapsedTime)}
</span>
</div>
);
}

View File

@@ -0,0 +1,6 @@
export function formatElapsedTime(ms: number): string {
const seconds = Math.floor(ms / 1000);
const minutes = Math.floor(seconds / 60);
const remainingSeconds = seconds % 60;
return `${minutes}:${remainingSeconds.toString().padStart(2, "0")}`;
}

View File

@@ -6,7 +6,7 @@ import {
useState,
} from "react";
interface UseChatInputArgs {
interface Args {
onSend: (message: string) => void;
disabled?: boolean;
maxRows?: number;
@@ -18,7 +18,7 @@ export function useChatInput({
disabled = false,
maxRows = 5,
inputId = "chat-input",
}: UseChatInputArgs) {
}: Args) {
const [value, setValue] = useState("");
const [hasMultipleLines, setHasMultipleLines] = useState(false);

View File

@@ -0,0 +1,240 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import React, {
KeyboardEvent,
useCallback,
useEffect,
useRef,
useState,
} from "react";
const MAX_RECORDING_DURATION = 2 * 60 * 1000; // 2 minutes in ms
interface Args {
setValue: React.Dispatch<React.SetStateAction<string>>;
disabled?: boolean;
isStreaming?: boolean;
value: string;
baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
}
export function useVoiceRecording({
setValue,
disabled = false,
isStreaming = false,
value,
baseHandleKeyDown,
}: Args) {
const [isRecording, setIsRecording] = useState(false);
const [isTranscribing, setIsTranscribing] = useState(false);
const [error, setError] = useState<string | null>(null);
const [elapsedTime, setElapsedTime] = useState(0);
const mediaRecorderRef = useRef<MediaRecorder | null>(null);
const chunksRef = useRef<Blob[]>([]);
const timerRef = useRef<NodeJS.Timeout | null>(null);
const startTimeRef = useRef<number>(0);
const streamRef = useRef<MediaStream | null>(null);
const isRecordingRef = useRef(false);
const isSupported =
typeof window !== "undefined" &&
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
const clearTimer = useCallback(() => {
if (timerRef.current) {
clearInterval(timerRef.current);
timerRef.current = null;
}
}, []);
const cleanup = useCallback(() => {
clearTimer();
if (streamRef.current) {
streamRef.current.getTracks().forEach((track) => track.stop());
streamRef.current = null;
}
mediaRecorderRef.current = null;
chunksRef.current = [];
setElapsedTime(0);
}, [clearTimer]);
const handleTranscription = useCallback(
(text: string) => {
setValue((prev) => {
const trimmedPrev = prev.trim();
if (trimmedPrev) {
return `${trimmedPrev} ${text}`;
}
return text;
});
},
[setValue],
);
const transcribeAudio = useCallback(
async (audioBlob: Blob) => {
setIsTranscribing(true);
setError(null);
try {
const formData = new FormData();
formData.append("audio", audioBlob);
const response = await fetch("/api/transcribe", {
method: "POST",
body: formData,
});
if (!response.ok) {
const data = await response.json().catch(() => ({}));
throw new Error(data.error || "Transcription failed");
}
const data = await response.json();
if (data.text) {
handleTranscription(data.text);
}
} catch (err) {
const message =
err instanceof Error ? err.message : "Transcription failed";
setError(message);
console.error("Transcription error:", err);
} finally {
setIsTranscribing(false);
}
},
[handleTranscription],
);
const stopRecording = useCallback(() => {
if (mediaRecorderRef.current && isRecordingRef.current) {
mediaRecorderRef.current.stop();
isRecordingRef.current = false;
setIsRecording(false);
clearTimer();
}
}, [clearTimer]);
const startRecording = useCallback(async () => {
if (disabled || isRecordingRef.current || isTranscribing) return;
setError(null);
chunksRef.current = [];
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
streamRef.current = stream;
const mediaRecorder = new MediaRecorder(stream, {
mimeType: MediaRecorder.isTypeSupported("audio/webm")
? "audio/webm"
: "audio/mp4",
});
mediaRecorderRef.current = mediaRecorder;
mediaRecorder.ondataavailable = (event) => {
if (event.data.size > 0) {
chunksRef.current.push(event.data);
}
};
mediaRecorder.onstop = async () => {
const audioBlob = new Blob(chunksRef.current, {
type: mediaRecorder.mimeType,
});
// Cleanup stream
if (streamRef.current) {
streamRef.current.getTracks().forEach((track) => track.stop());
streamRef.current = null;
}
if (audioBlob.size > 0) {
await transcribeAudio(audioBlob);
}
};
mediaRecorder.start(1000); // Collect data every second
isRecordingRef.current = true;
setIsRecording(true);
startTimeRef.current = Date.now();
// Start elapsed time timer
timerRef.current = setInterval(() => {
const elapsed = Date.now() - startTimeRef.current;
setElapsedTime(elapsed);
// Auto-stop at max duration
if (elapsed >= MAX_RECORDING_DURATION) {
stopRecording();
}
}, 100);
} catch (err) {
console.error("Failed to start recording:", err);
if (err instanceof DOMException && err.name === "NotAllowedError") {
setError("Microphone permission denied");
} else {
setError("Failed to access microphone");
}
cleanup();
}
}, [disabled, isTranscribing, stopRecording, transcribeAudio, cleanup]);
const toggleRecording = useCallback(() => {
if (isRecording) {
stopRecording();
} else {
startRecording();
}
}, [isRecording, startRecording, stopRecording]);
const { toast } = useToast();
useEffect(() => {
if (error) {
toast({
title: "Voice recording failed",
description: error,
variant: "destructive",
});
}
}, [error, toast]);
const handleKeyDown = useCallback(
(event: KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === " " && !value.trim() && !isTranscribing) {
event.preventDefault();
toggleRecording();
return;
}
baseHandleKeyDown(event);
},
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
);
const showMicButton = isSupported && !isStreaming;
const isInputDisabled = disabled || isStreaming || isTranscribing;
// Cleanup on unmount
useEffect(() => {
return () => {
cleanup();
};
}, [cleanup]);
return {
isRecording,
isTranscribing,
error,
elapsedTime,
startRecording,
stopRecording,
toggleRecording,
isSupported,
handleKeyDown,
showMicButton,
isInputDisabled,
audioStream: streamRef.current,
};
}

View File

@@ -1,6 +1,8 @@
"use client";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import { cn } from "@/lib/utils";
import { EyeSlash } from "@phosphor-icons/react";
import React from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
@@ -29,12 +31,88 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
type?: string;
}
/**
* Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend.
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
*
* Uses the generated API URL helper and routes through the Next.js proxy
* which handles authentication and proper backend routing.
*/
/**
* URL transformer for ReactMarkdown.
* Converts workspace:// URLs to proxy URLs that route through Next.js to the backend.
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
*
* This is needed because ReactMarkdown sanitizes URLs and only allows
* http, https, mailto, and tel protocols by default.
*/
function resolveWorkspaceUrl(src: string): string {
if (src.startsWith("workspace://")) {
const fileId = src.replace("workspace://", "");
// Use the generated API URL helper to get the correct path
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
// Route through the Next.js proxy (same pattern as customMutator for client-side)
return `/api/proxy${apiPath}`;
}
return src;
}
/**
* Check if the image URL is a workspace file (AI cannot see these yet).
* After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/...
*/
function isWorkspaceImage(src: string | undefined): boolean {
return src?.includes("/workspace/files/") ?? false;
}
/**
* Custom image component that shows an indicator when the AI cannot see the image.
* Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
*/
function MarkdownImage(props: Record<string, unknown>) {
const src = props.src as string | undefined;
const alt = props.alt as string | undefined;
const aiCannotSee = isWorkspaceImage(src);
// If no src, show a placeholder
if (!src) {
return (
<span className="my-2 inline-block rounded border border-amber-200 bg-amber-50 px-2 py-1 text-sm text-amber-700">
[Image: {alt || "missing src"}]
</span>
);
}
return (
<span className="relative my-2 inline-block">
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={src}
alt={alt || "Image"}
className="h-auto max-w-full rounded-md border border-zinc-200"
loading="lazy"
/>
{aiCannotSee && (
<span
className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
title="The AI cannot see this image"
>
<EyeSlash size={14} />
<span>AI cannot see this image</span>
</span>
)}
</span>
);
}
export function MarkdownContent({ content, className }: MarkdownContentProps) {
return (
<div className={cn("markdown-content", className)}>
<ReactMarkdown
skipHtml={true}
remarkPlugins={[remarkGfm]}
urlTransform={resolveWorkspaceUrl}
components={{
code: ({ children, className, ...props }: CodeProps) => {
const isInline = !className?.includes("language-");
@@ -206,6 +284,9 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
{children}
</td>
),
img: ({ src, alt, ...props }) => (
<MarkdownImage src={src} alt={alt} {...props} />
),
}}
>
{content}

View File

@@ -37,6 +37,87 @@ export function getErrorMessage(result: unknown): string {
return "An error occurred";
}
/**
* Check if a value is a workspace file reference.
*/
function isWorkspaceRef(value: unknown): value is string {
return typeof value === "string" && value.startsWith("workspace://");
}
/**
* Check if a workspace reference appears to be an image based on common patterns.
* Since workspace refs don't have extensions, we check the context or assume image
* for certain block types.
*
* TODO: Replace keyword matching with MIME type encoded in workspace ref.
* e.g., workspace://abc123#image/png or workspace://abc123#video/mp4
* This would let frontend render correctly without fragile keyword matching.
*/
function isLikelyImageRef(value: string, outputKey?: string): boolean {
if (!isWorkspaceRef(value)) return false;
// Check output key name for video-related hints (these are NOT images)
const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"];
if (outputKey) {
const lowerKey = outputKey.toLowerCase();
if (videoKeywords.some((kw) => lowerKey.includes(kw))) {
return false;
}
}
// Check output key name for image-related hints
const imageKeywords = [
"image",
"img",
"photo",
"picture",
"thumbnail",
"avatar",
"icon",
"screenshot",
];
if (outputKey) {
const lowerKey = outputKey.toLowerCase();
if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
return true;
}
}
// Default to treating workspace refs as potential images
// since that's the most common case for generated content
return true;
}
/**
* Format a single output value, converting workspace refs to markdown images.
*/
function formatOutputValue(value: unknown, outputKey?: string): string {
if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
// Format as markdown image
return `![${outputKey || "Generated image"}](${value})`;
}
if (typeof value === "string") {
// Check for data URIs (images)
if (value.startsWith("data:image/")) {
return `![${outputKey || "Generated image"}](${value})`;
}
return value;
}
if (Array.isArray(value)) {
return value
.map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`))
.join("\n\n");
}
if (typeof value === "object" && value !== null) {
return JSON.stringify(value, null, 2);
}
return String(value);
}
function getToolCompletionPhrase(toolName: string): string {
const toolCompletionPhrases: Record<string, string> = {
add_understanding: "Updated your business information",
@@ -127,10 +208,26 @@ export function formatToolResponse(result: unknown, toolName: string): string {
case "block_output":
const blockName = (response.block_name as string) || "Block";
const outputs = response.outputs as Record<string, unknown> | undefined;
const outputs = response.outputs as Record<string, unknown[]> | undefined;
if (outputs && Object.keys(outputs).length > 0) {
const outputKeys = Object.keys(outputs);
return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`;
const formattedOutputs: string[] = [];
for (const [key, values] of Object.entries(outputs)) {
if (!Array.isArray(values) || values.length === 0) continue;
// Format each value in the output array
for (const value of values) {
const formatted = formatOutputValue(value, key);
if (formatted) {
formattedOutputs.push(formatted);
}
}
}
if (formattedOutputs.length > 0) {
return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`;
}
return `${blockName} executed successfully.`;
}
return `${blockName} executed successfully.`;

View File

@@ -53,7 +53,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Block Installation](block-integrations/basic.md#block-installation) | Given a code string, this block allows the verification and installation of a block code into the system |
| [Concatenate Lists](block-integrations/basic.md#concatenate-lists) | Concatenates multiple lists into a single list |
| [Dictionary Is Empty](block-integrations/basic.md#dictionary-is-empty) | Checks if a dictionary is empty |
| [File Store](block-integrations/basic.md#file-store) | Stores the input file in the temporary directory |
| [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path |
| [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value |
| [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list |
| [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering |

View File

@@ -709,7 +709,7 @@ This is useful for conditional logic where you need to verify if data was return
## File Store
### What it is
Stores the input file in the temporary directory.
Downloads and stores a file from a URL, data URI, or local path. Use this to fetch images, documents, or other files for processing. In CoPilot: saves to workspace (use list_workspace_files to see it). In graphs: outputs a data URI to pass to other blocks.
### How it works
<!-- MANUAL: how_it_works -->
@@ -722,15 +722,15 @@ The block outputs a file path that other blocks can use to access the stored fil
| Input | Description | Type | Required |
|-------|-------------|------|----------|
| file_in | The file to store in the temporary directory, it can be a URL, data URI, or local path. | str (file) | Yes |
| base_64 | Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks). | bool | No |
| file_in | The file to download and store. Can be a URL (https://...), data URI, or local path. | str (file) | Yes |
| base_64 | Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks). | bool | No |
### Outputs
| Output | Description | Type |
|--------|-------------|------|
| error | Error message if the operation failed | str |
| file_out | The relative path to the stored file in the temporary directory. | str (file) |
| file_out | Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks. | str (file) |
### Possible use case
<!-- MANUAL: use_case -->

View File

@@ -12,7 +12,7 @@ Block to attach an audio file to a video file using moviepy.
<!-- MANUAL: how_it_works -->
This block combines a video file with an audio file using the moviepy library. The audio track is attached to the video, optionally with volume adjustment via the volume parameter (1.0 = original volume).
Input files can be URLs, data URIs, or local paths. The output can be returned as either a file path or base64 data URI.
Input files can be URLs, data URIs, or local paths. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
<!-- END MANUAL -->
### Inputs
@@ -22,7 +22,6 @@ Input files can be URLs, data URIs, or local paths. The output can be returned a
| video_in | Video input (URL, data URI, or local path). | str (file) | Yes |
| audio_in | Audio input (URL, data URI, or local path). | str (file) | Yes |
| volume | Volume scale for the newly attached audio track (1.0 = original). | float | No |
| output_return_type | Return the final output as a relative path or base64 data URI. | "file_path" \| "data_uri" | No |
### Outputs
@@ -51,7 +50,7 @@ Block to loop a video to a given duration or number of repeats.
<!-- MANUAL: how_it_works -->
This block extends a video by repeating it to reach a target duration or number of loops. Set duration to specify the total length in seconds, or use n_loops to repeat the video a specific number of times.
The looped video is seamlessly concatenated and can be output as a file path or base64 data URI.
The looped video is seamlessly concatenated. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
<!-- END MANUAL -->
### Inputs
@@ -61,7 +60,6 @@ The looped video is seamlessly concatenated and can be output as a file path or
| video_in | The input video (can be a URL, data URI, or local path). | str (file) | Yes |
| duration | Target duration (in seconds) to loop the video to. If omitted, defaults to no looping. | float | No |
| n_loops | Number of times to repeat the video. If omitted, defaults to 1 (no repeat). | int | No |
| output_return_type | How to return the output video. Either a relative path or base64 data URI. | "file_path" \| "data_uri" | No |
### Outputs

View File

@@ -277,6 +277,50 @@ async def run(
token = credentials.api_key.get_secret_value()
```
### Handling Files
When your block works with files (images, videos, documents), use `store_media_file()`:
```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
):
# PROCESSING: Need local file path for tools like ffmpeg, MoviePy, PIL
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# EXTERNAL API: Need base64 content for APIs like Replicate, OpenAI
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# OUTPUT: Return to user/next block (auto-adapts to context)
result = await store_media_file(
file=generated_url,
execution_context=execution_context,
return_format="for_block_output", # workspace:// in CoPilot, data URI in graphs
)
yield "image_url", result
```
**Return format options:**
- `"for_local_processing"` - Local file path for processing tools
- `"for_external_api"` - Data URI for external APIs needing base64
- `"for_block_output"` - **Always use for outputs** - automatically picks best format
## Testing Your Block
```bash

View File

@@ -111,6 +111,71 @@ Follow these steps to create and test a new block:
- `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run"
- `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed
- `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed.
- `execution_context`: An `ExecutionContext` object containing user_id, graph_exec_id, workspace_id, and session_id. Required for file handling.
### Handling Files in Blocks
When your block needs to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. This function handles downloading, validation, virus scanning, and storage.
**Import:**
```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
```
**The `return_format` parameter determines what you get back:**
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# PROCESSING: Need to work with file locally (ffmpeg, MoviePy, PIL)
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path, ffmpeg, subprocess, etc.
full_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
# EXTERNAL API: Need to send content to an API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "..." - send to external API
# OUTPUT: Returning result from block to user/next block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123" (persistent, context-efficient)
# In graphs: result_url = "data:image/png;base64,..." (for next block/display)
```
**Key points:**
- `for_block_output` is the **only** format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never manually check for `workspace_id` - let `for_block_output` handle the logic
- The function handles URLs, data URIs, `workspace://` references, and local paths as input
### Field Types