Compare commits

...

28 Commits

Author SHA1 Message Date
Nicholas Tindle
bec0157f9e Update migration to retain 'search' column
Removed the dropping of the 'search' column and its associated index from the migration script.
2026-01-27 23:01:19 -06:00
Nicholas Tindle
57f44e166a fix(backend): update HTTP block tests for execution_context
Update SendAuthenticatedWebRequestBlock to use execution_context
instead of separate graph_exec_id/user_id parameters, matching
the parent class signature.

Update test_http.py to pass execution_context to all test calls.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 22:59:52 -06:00
Nicholas Tindle
2c678f2658 refactor(backend): rename return_format options for clarity and add auto-fallback
Rename store_media_file() return_format options to make intent clear:
- "local_path" -> "for_local_processing" (ffmpeg, MoviePy, PIL)
- "data_uri" -> "for_external_api" (Replicate, OpenAI APIs)
- "workspace_ref" -> "for_block_output" (auto-adapts to context)

The "for_block_output" format now gracefully handles both contexts:
- CoPilot (has workspace): returns workspace:// reference
- Graph execution (no workspace): falls back to data URI

This prevents blocks from failing in graph execution while still
providing workspace persistence in CoPilot.

Also adds documentation to CLAUDE.md, new_blocks.md, and
block-sdk-guide.md explaining when to use each format.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:53:09 -06:00
Nicholas Tindle
669e33d709 chore: remove IDEAS.md from tracking
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:15:17 -06:00
Nicholas Tindle
953e7a5afb refactor(backend): replace return_content/save_to_workspace with return_format
Simplify store_media_file API with a single return_format parameter:

- "local_path": Return relative path (for local processing like MoviePy)
- "data_uri": Return base64 data URI (for external APIs like Replicate)
- "workspace_ref": Save to workspace and return workspace://id (for CoPilot)

This replaces the confusing combination of return_content and save_to_workspace
parameters. The old parameters are deprecated but still work via a compatibility
layer.

Updated all blocks to use the new explicit return_format parameter:
- Local processing: return_format="local_path"
- External APIs: return_format="data_uri"
- CoPilot outputs: return_format="workspace_ref"

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:12:29 -06:00
Nicholas Tindle
e9c55ed5a3 fix(backend): add save_to_workspace param for external API content
When blocks need to pass content to external APIs (Replicate, Discord),
they need data URIs, not workspace references. Added `save_to_workspace`
parameter to control this:

- save_to_workspace=True (default): save to workspace, return ref
- save_to_workspace=False: don't save, return data URI for API use

Updated blocks:
- AIImageEditorBlock (Flux Kontext) - input for Replicate
- AIImageCustomizerBlock - input for Replicate
- SendDiscordFileBlock - input for Discord

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:04:44 -06:00
Nicholas Tindle
ce3b8fa8d8 fix(backend): return data URI when reading workspace files for external APIs
When a block reads from workspace:// and needs content for an external API
(e.g., AIImageEditorBlock sending to Replicate), return data URI instead
of workspace reference.

Logic:
- workspace:// input + return_content=True → data URI (for external APIs)
- workspace:// input + return_content=False → local path (for processing)
- URL/data URI + return_content=True → save to workspace, return ref
- URL/data URI + return_content=False → local path

Fixes AIImageEditorBlock "Does not match format 'uri'" error when input
is a workspace reference.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:54:06 -06:00
Nicholas Tindle
ce67b7eca4 fix(backend): resolve workspace:// refs to local paths for file processing
When blocks need to process files locally (MoviePy, ffmpeg, etc.), they call
store_media_file with return_content=False expecting a local file path.

Previously, this always returned a workspace:// reference when workspace
was available, causing errors like:
  "File does not exist: /tmp/exec_file/.../workspace:/abc123"

Now the logic is:
- return_content=True: return workspace:// ref (for CoPilot output persistence)
- return_content=False: return local relative path (for file processing)

Also prevents re-saving when input is already a workspace:// reference,
avoiding unique constraint violations on the (workspaceId, path) index.

Fixes MediaDurationBlock, FileReadBlock, LoopVideoBlock, AddAudioToVideoBlock

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:52:35 -06:00
Nicholas Tindle
0e34c7e5c4 Merge branch 'dev' into user-workspace 2026-01-27 20:32:29 -06:00
Nicholas Tindle
5c5dd160dd docs(frontend): regenerate OpenAPI spec with workspace endpoints
Adds workspace API endpoints to the generated OpenAPI specification:
- GET /api/workspace - Get workspace info
- GET /api/workspace/files - List workspace files
- POST /api/workspace/files - Upload file
- GET /api/workspace/files/{id} - Get file info
- GET /api/workspace/files/{id}/download - Download file
- DELETE /api/workspace/files/{id} - Delete file

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:22:30 -06:00
Nicholas Tindle
759248b7fe refactor(backend/blocks): update block signatures to use ExecutionContext
Update remaining blocks to use the unified ExecutionContext parameter
instead of individual graph_exec_id, user_id, node_exec_id parameters.

Blocks updated:
- FileStoreBlock, FileReadBlock (basic.py, text.py)
- AgentFileInputBlock (io.py)
- MediaDurationBlock, LoopVideoBlock, AddAudioToVideoBlock (media.py)
- SendWebRequestBlock, SendAuthenticatedWebRequestBlock (http.py)
- ScreenshotWebPageBlock (screenshotone.py)
- ReadSpreadsheetBlock (spreadsheet.py)
- SendDiscordFileBlock (discord/bot_blocks.py)
- GmailSendBlock (google/gmail.py)
- Updated test for store_media_file

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:22:11 -06:00
Nicholas Tindle
ca5758cce6 feat(backend): pass workspace context through executor
Update executor to propagate workspace context:
- Pass workspace_id in execution kwargs
- Update test utilities with workspace support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:16:34 -06:00
Nicholas Tindle
d40df5a8c8 feat(frontend): render workspace images in chat with AI visibility indicator
Add frontend support for workspace:// image references:
- MarkdownContent: transform workspace:// URLs using generated API
- Route through /api/proxy for proper auth handling
- Add "AI cannot see this image" overlay for workspace files
- Update proxy route to handle binary file downloads
- Format block outputs with workspace refs as markdown images

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:16:19 -06:00
Nicholas Tindle
0db228ed43 feat(backend/blocks): store generated media to workspace
Update media-generating blocks to save outputs to workspace:
- AIImageCustomizerBlock: store customized images
- AIImageGeneratorBlock: store generated images
- AIShortformVideoCreatorBlock (3 blocks): store videos
- BannerbearTextOverlayBlock: store generated images
- AIVideoGeneratorBlock (FAL): store generated videos
- AIImageEditorBlock (Flux Kontext): store edited images
- CreateTalkingAvatarVideoBlock: store avatar videos

All blocks now return workspace:// references instead of
direct URLs, enabling persistent storage and preventing
context bloat from large base64 data URIs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:16:03 -06:00
Nicholas Tindle
590f434d0a feat(backend/chat): update CoPilot execution context mapping
Update run_block.py with proper CoPilot-to-graph context mapping:
- graph_id = copilot-session-{session_id} (agent = session)
- graph_exec_id = copilot-session-{session_id} (run = session)
- graph_version = 1 (versions are 1-indexed)
- Pass workspace_id and session_id for file operations
- Each chat session is its own agent with one continuous run

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:15:26 -06:00
Nicholas Tindle
8f171a0537 feat(backend/chat): add CoPilot workspace tools
Add tools for CoPilot to manage workspace files:
- list_workspace_files: list files with session scoping
- read_workspace_file: read file content or metadata
- write_workspace_file: save content to workspace
- delete_workspace_file: remove files
- Session-aware operations (default to current session folder)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:15:14 -06:00
Nicholas Tindle
c814a43465 feat(backend/api): add workspace REST endpoints
Add API routes for workspace file management:
- GET /api/workspace - get workspace info
- POST /api/workspace/files - upload file
- GET /api/workspace/files - list files
- GET /api/workspace/files/{id} - get file info
- GET /api/workspace/files/{id}/download - download file
- DELETE /api/workspace/files/{id} - soft delete
- Stream file content when signed URLs unavailable

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:15:00 -06:00
Nicholas Tindle
5923041fe8 feat(backend): integrate workspace storage into store_media_file
Update store_media_file to use workspace when available:
- Save files to workspace instead of temp exec_file dir
- Return workspace:// references instead of base64 data URIs
- Handle workspace:// input references (read from workspace)
- Pass session_id to WorkspaceManager for session scoping
- Prevents context bloat (100KB file = ~133KB as base64)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:14:44 -06:00
Nicholas Tindle
936a2d70db feat(backend): add workspace_id and session_id to ExecutionContext
Extend ExecutionContext with workspace fields:
- workspace_id: user's workspace for file persistence
- session_id: chat session for file isolation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:14:26 -06:00
Nicholas Tindle
80c54b7f46 feat(backend): add workspace storage backend abstraction
Add storage layer for workspace files:
- WorkspaceStorageBackend: abstract interface
- GCSWorkspaceStorage: Google Cloud Storage implementation
- LocalWorkspaceStorage: local filesystem for self-hosted
- WorkspaceManager: high-level file operations with session scoping
- Session-scoped virtual paths: /sessions/{session_id}/{filename}
- Fallback to API proxy when GCS signed URLs unavailable

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:14:09 -06:00
Nicholas Tindle
28caf01ca7 feat(backend/db): add UserWorkspace and UserWorkspaceFile models
Add database schema for persistent user workspace storage:
- UserWorkspace: one-per-user workspace container
- UserWorkspaceFile: file metadata with virtual paths, checksums, and source tracking
- WorkspaceFileSource enum: UPLOAD, EXECUTION, COPILOT, IMPORT
- CRUD operations in data/workspace.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:13:49 -06:00
Nicholas Tindle
ea035224bc feat(copilot): Increase max_agent_runs and max_agent_schedules (#11865)
<!-- Clearly explain the need for these changes: -->
Config change to increase the max times an agent can run in the chat and
the max number of scheduels created by copilot in one chat

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Increases per-chat operational limits for Copilot.
> 
> - Bumps `max_agent_runs` default from `3` to `30` in `ChatConfig`
> - Bumps `max_agent_schedules` default from `3` to `30` in `ChatConfig`
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
93cbae6d27. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-01-28 01:08:02 +00:00
Nicholas Tindle
62813a1ea6 Delete backend/blocks/video/__init__.py (#11864)
<!-- Clearly explain the need for these changes: -->
oops file
### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->
removes file that should have not been commited

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Removes erroneous `backend/blocks/video/__init__.py`, eliminating an
unintended `video` package.
> 
> - Deletes a placeholder comment-only file
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
3b84576c33. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-01-28 00:58:49 +00:00
Bently
67405f7eb9 fix(copilot): ensure tool_call/tool_response pairs stay intact during context compaction (#11863)
## Summary

Fixes context compaction breaking tool_call/tool_response pairs, causing
API validation errors.

## Problem

When context compaction slices messages with `messages[-KEEP_RECENT:]`,
a naive slice can separate an assistant message containing `tool_calls`
from its corresponding tool response messages. This causes API
validation errors like:

```
messages.0.content.1: unexpected 'tool_use_id' found in 'tool_result' blocks: orphan_12345.
Each 'tool_result' block must have a corresponding 'tool_use' block in the previous message.
```

## Solution

Added `_ensure_tool_pairs_intact()` helper function that:
1. Detects orphan tool responses in a slice (tool messages whose
`tool_call_id` has no matching assistant message)
2. Extends the slice backwards to include the missing assistant messages
3. Falls back to removing orphan tool responses if the assistant cannot
be found (edge case)

Applied this safeguard to:
- The initial `KEEP_RECENT` slice (line ~990)
- The progressive fallback slices when still over token limit (line
~1079)

## Testing

- Syntax validated with `python -m py_compile`
- Logic reviewed for correctness

## Linear

Fixes SECRT-1839

---
*Debugged by Toran & Orion in #agpt Discord*
2026-01-28 00:21:54 +00:00
Zamil Majdy
171ff6e776 feat(backend): persist long-running tool results to survive SSE disconnects (#11856)
## Summary

Agent generation (`create_agent`, `edit_agent`) can take 1-5 minutes.
Previously, if the user closed their browser tab during this time:
1. The SSE connection would die
2. The tool execution would be cancelled via `CancelledError`
3. The result would be lost - even if the agent-generator service
completed successfully

This PR ensures long-running tool operations survive SSE disconnections.

### Changes 🏗️

**Backend:**
- **base.py**: Added `is_long_running` property to `BaseTool` for tools
to opt-in to background execution
- **create_agent.py / edit_agent.py**: Set `is_long_running = True`
- **models.py**: Added `OperationStartedResponse`,
`OperationPendingResponse`, `OperationInProgressResponse` types
- **service.py**: Modified `_yield_tool_call()` to:
  - Check if tool is `is_long_running`
  - Save "pending" message to chat history immediately
  - Spawn background task that runs independently of SSE
  - Return `operation_started` immediately (don't wait)
  - Update chat history with result when background task completes
- Track running operations for idempotency (prevents duplicate ops on
refresh)
- **db.py**: Added `update_tool_message_content()` to update pending
messages
- **model.py**: Added `invalidate_session_cache()` to clear Redis after
background completion

**Frontend:**
- **useChatMessage.ts**: Added operation message types
- **helpers.ts**: Handle `operation_started`, `operation_pending`,
`operation_in_progress` response types
- **PendingOperationWidget**: New component to display operation status
with spinner
- **ChatMessage.tsx**: Render `PendingOperationWidget` for operation
messages

### How It Works

```
User Request → Save "pending" message → Spawn background task → Return immediately
                                              ↓
                                     Task runs independently of SSE
                                              ↓
                                     On completion: Update message in chat history
                                              ↓
                                     User refreshes → Loads history → Sees result
```

### User Experience

1. User requests agent creation
2. Sees "Agent creation started. You can close this tab - check your
library in a few minutes."
3. Can close browser tab safely
4. When they return, chat shows the completed result (or error)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] pyright passes (0 errors)
  - [x] TypeScript checks pass
  - [x] Formatters applied

### Test Plan

1. Start agent creation in copilot
2. Close browser tab immediately after seeing "operation_started" 
3. Wait 2-3 minutes
4. Reopen chat
5. Verify: Chat history shows completion message and agent appears in
library

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
2026-01-28 05:09:34 +07:00
Lluis Agusti
349b1f9c79 hotfix(frontend): copilot session handling refinements... 2026-01-28 02:53:45 +07:00
Lluis Agusti
277b0537e9 hotfix(frontend): copilot simplication... 2026-01-28 02:10:18 +07:00
Ubbe
071b3bb5cd fix(frontend): more copilot refinements (#11858)
## Changes 🏗️

On the **Copilot** page:

- prevent unnecessary sidebar repaints 
- show a disclaimer when switching chats on the sidebar to terminate a
current stream
- handle loading better
- save streams better when disconnecting


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app locally and test the above
2026-01-28 00:49:28 +07:00
78 changed files with 5824 additions and 945 deletions

View File

@@ -194,6 +194,50 @@ ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
**Handling files in blocks with `store_media_file()`:**
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`

View File

@@ -33,9 +33,15 @@ class ChatConfig(BaseSettings):
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=3, description="Maximum number of agent schedules"
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Langfuse Prompt Management Configuration

View File

@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(
session_id: str,
tool_call_id: str,
new_content: str,
) -> bool:
"""Update the content of a tool message in chat history.
Used by background tasks to update pending operation messages with final results.
Args:
session_id: The chat session ID.
tool_call_id: The tool call ID to find the message.
new_content: The new content to set.
Returns:
True if a message was updated, False otherwise.
"""
try:
result = await PrismaChatMessage.prisma().update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={
"content": new_content,
},
)
if result == 0:
logger.warning(
f"No message found to update for session {session_id}, "
f"tool_call_id {tool_call_id}"
)
return False
return True
except Exception as e:
logger.error(
f"Failed to update tool message for session {session_id}, "
f"tool_call_id {tool_call_id}: {e}"
)
return False

View File

@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
Used by background tasks to ensure fresh data is loaded on next access.
This is best-effort - Redis failures are logged but don't fail the operation.
"""
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Best-effort: log but don't fail - cache will expire naturally
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)

View File

@@ -17,6 +17,7 @@ from openai import (
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
format_understanding_for_prompt,
get_business_understanding,
@@ -24,6 +25,7 @@ from backend.data.understanding import (
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from . import db as chat_db
from .config import ChatConfig
from .model import (
ChatMessage,
@@ -31,6 +33,7 @@ from .model import (
Usage,
cache_chat_session,
get_chat_session,
invalidate_session_cache,
update_session_title,
upsert_chat_session,
)
@@ -48,8 +51,13 @@ from .response_model import (
StreamToolOutputAvailable,
StreamUsage,
)
from .tools import execute_tool, tools
from .tools.models import ErrorResponse
from .tools import execute_tool, get_tool, tools
from .tools.models import (
ErrorResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
)
from .tracking import track_user_message
logger = logging.getLogger(__name__)
@@ -61,6 +69,43 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
langfuse = get_client()
# Redis key prefix for tracking running long-running operations
# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh
RUNNING_OPERATION_PREFIX = "chat:running_operation:"
# Module-level set to hold strong references to background tasks.
# This prevents asyncio from garbage collecting tasks before they complete.
# Tasks are automatically removed on completion via done_callback.
_background_tasks: set[asyncio.Task] = set()
async def _mark_operation_started(tool_call_id: str) -> bool:
"""Mark a long-running operation as started (Redis-based).
Returns True if successfully marked (operation was not already running),
False if operation was already running (lost race condition).
Raises exception if Redis is unavailable (fail-closed).
"""
redis = await get_redis_async()
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
# SETNX with TTL - atomic "set if not exists"
result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True)
return result is not None
async def _mark_operation_completed(tool_call_id: str) -> None:
"""Mark a long-running operation as completed (remove Redis key).
This is best-effort - if Redis fails, the TTL will eventually clean up.
"""
try:
redis = await get_redis_async()
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
await redis.delete(key)
except Exception as e:
# Non-critical: TTL will clean up eventually
logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}")
class LangfuseNotConfiguredError(Exception):
"""Raised when Langfuse is required but not configured."""
@@ -315,6 +360,7 @@ async def stream_chat_completion(
has_yielded_end = False
has_yielded_error = False
has_done_tool_call = False
has_long_running_tool_call = False # Track if we had a long-running tool call
has_received_text = False
text_streaming_ended = False
tool_response_messages: list[ChatMessage] = []
@@ -336,7 +382,6 @@ async def stream_chat_completion(
system_prompt=system_prompt,
text_block_id=text_block_id,
):
if isinstance(chunk, StreamTextStart):
# Emit text-start before first text delta
if not has_received_text:
@@ -394,13 +439,34 @@ async def stream_chat_completion(
if isinstance(chunk.output, str)
else orjson.dumps(chunk.output).decode("utf-8")
)
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
tool_call_id=chunk.toolCallId,
# Skip saving long-running operation responses - messages already saved in _yield_tool_call
# Use JSON parsing instead of substring matching to avoid false positives
is_long_running_response = False
try:
parsed = orjson.loads(result_content)
if isinstance(parsed, dict) and parsed.get("type") in (
"operation_started",
"operation_in_progress",
):
is_long_running_response = True
except (orjson.JSONDecodeError, TypeError):
pass # Not JSON or not a dict - treat as regular response
if is_long_running_response:
# Remove from accumulated_tool_calls since assistant message was already saved
accumulated_tool_calls[:] = [
tc
for tc in accumulated_tool_calls
if tc["id"] != chunk.toolCallId
]
has_long_running_tool_call = True
else:
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
tool_call_id=chunk.toolCallId,
)
)
)
has_done_tool_call = True
# Track if any tool execution failed
if not chunk.success:
@@ -576,7 +642,14 @@ async def stream_chat_completion(
logger.info(
f"Extended session messages, new message_count={len(session.messages)}"
)
if messages_to_save or has_appended_streaming_message:
# Save if there are regular (non-long-running) tool responses or streaming message.
# Long-running tools save their own state, but we still need to save regular tools
# that may be in the same response.
has_regular_tool_responses = len(tool_response_messages) > 0
if has_regular_tool_responses or (
not has_long_running_tool_call
and (messages_to_save or has_appended_streaming_message)
):
await upsert_chat_session(session)
else:
logger.info(
@@ -585,7 +658,9 @@ async def stream_chat_completion(
)
# If we did a tool call, stream the chat completion again to get the next response
if has_done_tool_call:
# Skip only if ALL tools were long-running (they handle their own completion)
has_regular_tools = len(tool_response_messages) > 0
if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call):
logger.info(
"Tool call executed, streaming chat completion again to get assistant response"
)
@@ -725,6 +800,114 @@ async def _summarize_messages(
return summary or "No summary available."
def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
start_index: int,
) -> list[dict]:
"""
Ensure tool_call/tool_response pairs stay together after slicing.
When slicing messages for context compaction, a naive slice can separate
an assistant message containing tool_calls from its corresponding tool
response messages. This causes API validation errors (e.g., Anthropic's
"unexpected tool_use_id found in tool_result blocks").
This function checks for orphan tool responses in the slice and extends
backwards to include their corresponding assistant messages.
Args:
recent_messages: The sliced messages to validate
all_messages: The complete message list (for looking up missing assistants)
start_index: The index in all_messages where recent_messages begins
Returns:
A potentially extended list of messages with tool pairs intact
"""
if not recent_messages:
return recent_messages
# Collect all tool_call_ids from assistant messages in the slice
available_tool_call_ids: set[str] = set()
for msg in recent_messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
if tc_id:
available_tool_call_ids.add(tc_id)
# Find orphan tool responses (tool messages whose tool_call_id is missing)
orphan_tool_call_ids: set[str] = set()
for msg in recent_messages:
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id not in available_tool_call_ids:
orphan_tool_call_ids.add(tc_id)
if not orphan_tool_call_ids:
# No orphans, slice is valid
return recent_messages
# Find the assistant messages that contain the orphan tool_call_ids
# Search backwards from start_index in all_messages
messages_to_prepend: list[dict] = []
for i in range(start_index - 1, -1, -1):
msg = all_messages[i]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
if msg_tool_ids & orphan_tool_call_ids:
# This assistant message has tool_calls we need
# Also collect its contiguous tool responses that follow it
assistant_and_responses: list[dict] = [msg]
# Scan forward from this assistant to collect tool responses
for j in range(i + 1, start_index):
following_msg = all_messages[j]
if following_msg.get("role") == "tool":
tool_id = following_msg.get("tool_call_id")
if tool_id and tool_id in msg_tool_ids:
assistant_and_responses.append(following_msg)
else:
# Stop at first non-tool message
break
# Prepend the assistant and its tool responses (maintain order)
messages_to_prepend = assistant_and_responses + messages_to_prepend
# Mark these as found
orphan_tool_call_ids -= msg_tool_ids
# Also add this assistant's tool_call_ids to available set
available_tool_call_ids |= msg_tool_ids
if not orphan_tool_call_ids:
# Found all missing assistants
break
if orphan_tool_call_ids:
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = [
msg
for msg in recent_messages
if not (
msg.get("role") == "tool"
and msg.get("tool_call_id") in orphan_tool_call_ids
)
]
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
return recent_messages
async def _stream_chat_chunks(
session: ChatSession,
tools: list[ChatCompletionToolParam],
@@ -816,7 +999,15 @@ async def _stream_chat_chunks(
# Always attempt mitigation when over limit, even with few messages
if messages:
# Split messages based on whether system prompt exists
recent_messages = messages[-KEEP_RECENT:]
# Calculate start index for the slice
slice_start = max(0, len(messages_dict) - KEEP_RECENT)
recent_messages = messages_dict[-KEEP_RECENT:]
# Ensure tool_call/tool_response pairs stay together
# This prevents API errors from orphan tool responses
recent_messages = _ensure_tool_pairs_intact(
recent_messages, messages_dict, slice_start
)
if has_system_prompt:
# Keep system prompt separate, summarize everything between system and recent
@@ -903,6 +1094,13 @@ async def _stream_chat_chunks(
if len(recent_messages) >= keep_count
else recent_messages
)
# Ensure tool pairs stay intact in the reduced slice
reduced_slice_start = max(
0, len(recent_messages) - keep_count
)
reduced_recent = _ensure_tool_pairs_intact(
reduced_recent, recent_messages, reduced_slice_start
)
if has_system_prompt:
messages = [
system_msg,
@@ -961,7 +1159,10 @@ async def _stream_chat_chunks(
# Create a base list excluding system prompt to avoid duplication
# This is the pool of messages we'll slice from in the loop
base_msgs = messages[1:] if has_system_prompt else messages
# Use messages_dict for type consistency with _ensure_tool_pairs_intact
base_msgs = (
messages_dict[1:] if has_system_prompt else messages_dict
)
# Try progressively smaller keep counts
new_token_count = token_count # Initialize with current count
@@ -984,6 +1185,12 @@ async def _stream_chat_chunks(
# Slice from base_msgs to get recent messages (without system prompt)
recent_messages = base_msgs[-keep_count:]
# Ensure tool pairs stay intact in the reduced slice
reduced_slice_start = max(0, len(base_msgs) - keep_count)
recent_messages = _ensure_tool_pairs_intact(
recent_messages, base_msgs, reduced_slice_start
)
if has_system_prompt:
messages = [system_msg] + recent_messages
else:
@@ -1260,14 +1467,17 @@ async def _yield_tool_call(
"""
Yield a tool call and its execution result.
For long-running tools, yields heartbeat events every 15 seconds to keep
the SSE connection alive through proxies and load balancers.
For tools marked with `is_long_running=True` (like agent generation), spawns a
background task so the operation survives SSE disconnections. For other tools,
yields heartbeat events every 15 seconds to keep the SSE connection alive.
Raises:
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
KeyError: If expected tool call fields are missing
TypeError: If tool call structure is invalid
"""
import uuid as uuid_module
tool_name = tool_calls[yield_idx]["function"]["name"]
tool_call_id = tool_calls[yield_idx]["id"]
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
@@ -1285,7 +1495,151 @@ async def _yield_tool_call(
input=arguments,
)
# Run tool execution in background task with heartbeats to keep connection alive
# Check if this tool is long-running (survives SSE disconnection)
tool = get_tool(tool_name)
if tool and tool.is_long_running:
# Atomic check-and-set: returns False if operation already running (lost race)
if not await _mark_operation_started(tool_call_id):
logger.info(
f"Tool call {tool_call_id} already in progress, returning status"
)
# Build dynamic message based on tool name
if tool_name == "create_agent":
in_progress_msg = "Agent creation already in progress. Please wait..."
elif tool_name == "edit_agent":
in_progress_msg = "Agent edit already in progress. Please wait..."
else:
in_progress_msg = f"{tool_name} already in progress. Please wait..."
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=OperationInProgressResponse(
message=in_progress_msg,
tool_call_id=tool_call_id,
).model_dump_json(),
success=True,
)
return
# Generate operation ID
operation_id = str(uuid_module.uuid4())
# Build a user-friendly message based on tool and arguments
if tool_name == "create_agent":
agent_desc = arguments.get("description", "")
# Truncate long descriptions for the message
desc_preview = (
(agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc
)
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = arguments.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# Track appended messages for rollback on failure
assistant_message: ChatMessage | None = None
pending_message: ChatMessage | None = None
# Wrap session save and task creation in try-except to release lock on failure
try:
# Save assistant message with tool_call FIRST (required by LLM)
assistant_message = ChatMessage(
role="assistant",
content="",
tool_calls=[tool_calls[yield_idx]],
)
session.messages.append(assistant_message)
# Then save pending tool result
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
logger.info(
f"Saved pending operation {operation_id} for tool {tool_name} "
f"in session {session.session_id}"
)
# Store task reference in module-level set to prevent GC before completion
task = asyncio.create_task(
_execute_long_running_tool(
tool_name=tool_name,
parameters=arguments,
tool_call_id=tool_call_id,
operation_id=operation_id,
session_id=session.session_id,
user_id=session.user_id,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except Exception as e:
# Roll back appended messages to prevent data corruption on subsequent saves
if (
pending_message
and session.messages
and session.messages[-1] == pending_message
):
session.messages.pop()
if (
assistant_message
and session.messages
and session.messages[-1] == assistant_message
):
session.messages.pop()
# Release the Redis lock since the background task won't be spawned
await _mark_operation_completed(tool_call_id)
logger.error(
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
)
raise
# Return immediately - don't wait for completion
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
success=True,
)
return
# Normal flow: Run tool execution in background task with heartbeats
tool_task = asyncio.create_task(
execute_tool(
tool_name=tool_name,
@@ -1335,3 +1689,190 @@ async def _yield_tool_call(
)
yield tool_execution_response
async def _execute_long_running_tool(
tool_name: str,
parameters: dict[str, Any],
tool_call_id: str,
operation_id: str,
session_id: str,
user_id: str | None,
) -> None:
"""Execute a long-running tool in background and update chat history with result.
This function runs independently of the SSE connection, so the operation
survives if the user closes their browser tab.
"""
try:
# Load fresh session (not stale reference)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for background tool")
return
# Execute the actual tool
result = await execute_tool(
tool_name=tool_name,
parameters=parameters,
tool_call_id=tool_call_id,
user_id=user_id,
session=session,
)
# Update the pending message with result
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=(
result.output
if isinstance(result.output, str)
else orjson.dumps(result.output).decode("utf-8")
),
)
logger.info(f"Background tool {tool_name} completed for session {session_id}")
# Generate LLM continuation so user sees response when they poll/refresh
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
except Exception as e:
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
error_response = ErrorResponse(
message=f"Tool {tool_name} failed: {str(e)}",
)
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=error_response.model_dump_json(),
)
finally:
await _mark_operation_completed(tool_call_id)
async def _update_pending_operation(
session_id: str,
tool_call_id: str,
result: str,
) -> None:
"""Update the pending tool message with final result.
This is called by background tasks when long-running operations complete.
"""
# Update the message in database
updated = await chat_db.update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=result,
)
if updated:
# Invalidate Redis cache so next load gets fresh data
# Wrap in try/except to prevent cache failures from triggering error handling
# that would overwrite our successful DB update
try:
await invalidate_session_cache(session_id)
except Exception as e:
# Non-critical: cache will eventually be refreshed on next load
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
logger.info(
f"Updated pending operation for tool_call_id {tool_call_id} "
f"in session {session_id}"
)
else:
logger.warning(
f"Failed to update pending operation for tool_call_id {tool_call_id} "
f"in session {session_id}"
)
async def _generate_llm_continuation(
session_id: str,
user_id: str | None,
) -> None:
"""Generate an LLM response after a long-running tool completes.
This is called by background tasks to continue the conversation
after a tool result is saved. The response is saved to the database
so users see it when they refresh or poll.
"""
try:
# Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for LLM continuation")
return
# Build system prompt
system_prompt, _ = await _build_system_prompt(user_id)
# Build messages in OpenAI format
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam
system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages
# Build extra_body for tracing
extra_body: dict[str, Any] = {
"posthogProperties": {
"environment": settings.config.app_env.value,
},
}
if user_id:
extra_body["user"] = user_id[:128]
extra_body["posthogDistinctId"] = user_id
if session_id:
extra_body["session_id"] = session_id[:128]
# Make non-streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# No tools parameter = text-only response (no tool calls)
response = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
)
if response.choices and response.choices[0].message.content:
assistant_content = response.choices[0].message.content
# Reload session from DB to avoid race condition with user messages
# that may have been sent while we were generating the LLM response
fresh_session = await get_chat_session(session_id, user_id)
if not fresh_session:
logger.error(
f"Session {session_id} disappeared during LLM continuation"
)
return
# Save assistant message to database
assistant_message = ChatMessage(
role="assistant",
content=assistant_content,
)
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
logger.info(
f"Generated LLM continuation for session {session_id}, "
f"response length: {len(assistant_content)}"
)
else:
logger.warning(f"LLM continuation returned empty response for {session_id}")
except Exception as e:
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)

View File

@@ -18,6 +18,12 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_tools import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WriteWorkspaceFileTool,
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -37,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),
"delete_workspace_file": DeleteWorkspaceFileTool(),
}
# Export individual tool instances for backwards compatibility
@@ -49,6 +60,11 @@ tools: list[ChatCompletionToolParam] = [
]
def get_tool(tool_name: str) -> BaseTool | None:
"""Get a tool instance by name."""
return TOOL_REGISTRY.get(tool_name)
async def execute_tool(
tool_name: str,
parameters: dict[str, Any],
@@ -57,7 +73,7 @@ async def execute_tool(
tool_call_id: str,
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
tool = TOOL_REGISTRY.get(tool_name)
tool = get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")

View File

@@ -36,6 +36,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -42,6 +42,10 @@ class CreateAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -42,6 +42,10 @@ class EditAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -28,6 +28,17 @@ class ResponseType(str, Enum):
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
WORKSPACE_FILE_INFO = "workspace_file_info"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Base response model
@@ -334,3 +345,39 @@ class BlockOutputResponse(ToolResponseBase):
block_name: str
outputs: dict[str, list[Any]]
success: bool = True
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str

View File

@@ -1,6 +1,7 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from collections import defaultdict
from typing import Any
@@ -8,6 +9,7 @@ from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
@@ -223,11 +225,48 @@ class RunBlockTool(BaseTool):
)
try:
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": ExecutionContext(),
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
}
for field_name, cred_meta in matched_credentials.items():

View File

@@ -0,0 +1,619 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from prisma.enums import WorkspaceFileSource
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import MAX_FILE_SIZE_BYTES, WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
source: str
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count()
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
source=f.source,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + URL for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
download_url = await manager.get_download_url(target_file_id)
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
"Maximum file size is 100MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
if len(content) > MAX_FILE_SIZE_BYTES:
return ErrorResponse(
message=f"File too large. Maximum size is {MAX_FILE_SIZE_BYTES // (1024*1024)}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
source=WorkspaceFileSource.COPILOT,
source_session_id=session.session_id,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -0,0 +1 @@
# Workspace API feature module

View File

@@ -0,0 +1,85 @@
"""
Pydantic models for the Workspace API.
"""
from datetime import datetime
from typing import Any, Optional
from prisma.enums import WorkspaceFileSource
from pydantic import BaseModel, Field
class WorkspaceInfo(BaseModel):
"""Response model for workspace information."""
id: str
user_id: str
created_at: datetime
updated_at: datetime
file_count: int = 0
class WorkspaceFileInfo(BaseModel):
"""Response model for workspace file information."""
id: str
name: str
path: str
mime_type: str
size_bytes: int
checksum: Optional[str] = None
source: WorkspaceFileSource
source_exec_id: Optional[str] = None
source_session_id: Optional[str] = None
created_at: datetime
updated_at: datetime
metadata: dict[str, Any] = Field(default_factory=dict)
class WorkspaceFileListResponse(BaseModel):
"""Response model for listing workspace files."""
files: list[WorkspaceFileInfo]
total_count: int
path_filter: Optional[str] = None
class UploadFileRequest(BaseModel):
"""Request model for file upload metadata."""
filename: str
path: Optional[str] = None
mime_type: Optional[str] = None
overwrite: bool = False
class WriteFileRequest(BaseModel):
"""Request model for writing file content directly (for CoPilot tools)."""
filename: str
content_base64: str = Field(description="Base64-encoded file content")
path: Optional[str] = None
mime_type: Optional[str] = None
overwrite: bool = False
class UploadFileResponse(BaseModel):
"""Response model for file upload."""
file: WorkspaceFileInfo
message: str
class DeleteFileResponse(BaseModel):
"""Response model for file deletion."""
success: bool
file_id: str
message: str
class DownloadUrlResponse(BaseModel):
"""Response model for download URL."""
url: str
expires_in_seconds: int

View File

@@ -0,0 +1,495 @@
"""
Workspace API routes for managing user file storage.
"""
import base64
import logging
from typing import Annotated, Optional
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import File, Query, UploadFile
from fastapi.responses import Response
from prisma.enums import WorkspaceFileSource
from backend.data.workspace import (
count_workspace_files,
get_or_create_workspace,
get_workspace,
get_workspace_file,
get_workspace_file_by_path,
)
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import MAX_FILE_SIZE_BYTES, WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
from .models import (
DeleteFileResponse,
DownloadUrlResponse,
UploadFileResponse,
WorkspaceFileInfo,
WorkspaceFileListResponse,
WorkspaceInfo,
WriteFileRequest,
)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
dependencies=[fastapi.Security(requires_user)],
)
def _file_to_info(file) -> WorkspaceFileInfo:
"""Convert database file record to API response model."""
return WorkspaceFileInfo(
id=file.id,
name=file.name,
path=file.path,
mime_type=file.mimeType,
size_bytes=file.sizeBytes,
checksum=file.checksum,
source=file.source,
source_exec_id=file.sourceExecId,
source_session_id=file.sourceSessionId,
created_at=file.createdAt,
updated_at=file.updatedAt,
metadata=file.metadata if file.metadata else {},
)
@router.get(
"",
summary="Get workspace info",
response_model=WorkspaceInfo,
)
async def get_workspace_info(
user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> WorkspaceInfo:
"""
Get the current user's workspace information.
Creates workspace if it doesn't exist.
"""
workspace = await get_or_create_workspace(user_id)
file_count = await count_workspace_files(workspace.id)
return WorkspaceInfo(
id=workspace.id,
user_id=workspace.userId,
created_at=workspace.createdAt,
updated_at=workspace.updatedAt,
file_count=file_count,
)
@router.post(
"/files",
summary="Upload file to workspace",
response_model=UploadFileResponse,
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file: UploadFile = File(...),
path: Annotated[Optional[str], Query()] = None,
overwrite: Annotated[bool, Query()] = False,
) -> UploadFileResponse:
"""
Upload a file to the user's workspace.
- **file**: The file to upload (max 100MB)
- **path**: Optional virtual path (defaults to "/{filename}")
- **overwrite**: Whether to overwrite existing file at path
"""
workspace = await get_or_create_workspace(user_id)
manager = WorkspaceManager(user_id, workspace.id)
# Read file content
content = await file.read()
# Check file size
if len(content) > MAX_FILE_SIZE_BYTES:
raise fastapi.HTTPException(
status_code=413,
detail=f"File too large. Maximum size is {MAX_FILE_SIZE_BYTES // (1024*1024)}MB",
)
# Virus scan
filename = file.filename or "uploaded_file"
await scan_content_safe(content, filename=filename)
# Write file to workspace
try:
workspace_file = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=file.content_type,
source=WorkspaceFileSource.UPLOAD,
overwrite=overwrite,
)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=str(e))
return UploadFileResponse(
file=_file_to_info(workspace_file),
message="File uploaded successfully",
)
@router.post(
"/files/write",
summary="Write file content directly",
response_model=UploadFileResponse,
)
async def write_file_content(
user_id: Annotated[str, fastapi.Security(get_user_id)],
request: WriteFileRequest,
) -> UploadFileResponse:
"""
Write file content directly to workspace (for programmatic access).
- **filename**: Name for the file
- **content_base64**: Base64-encoded file content
- **path**: Optional virtual path (defaults to "/{filename}")
- **mime_type**: Optional MIME type (auto-detected if not provided)
- **overwrite**: Whether to overwrite existing file at path
"""
workspace = await get_or_create_workspace(user_id)
manager = WorkspaceManager(user_id, workspace.id)
# Decode content
try:
content = base64.b64decode(request.content_base64)
except Exception:
raise fastapi.HTTPException(
status_code=400, detail="Invalid base64-encoded content"
)
# Check file size
if len(content) > MAX_FILE_SIZE_BYTES:
raise fastapi.HTTPException(
status_code=413,
detail=f"File too large. Maximum size is {MAX_FILE_SIZE_BYTES // (1024*1024)}MB",
)
# Virus scan
await scan_content_safe(content, filename=request.filename)
# Write file to workspace
try:
workspace_file = await manager.write_file(
content=content,
filename=request.filename,
path=request.path,
mime_type=request.mime_type,
source=WorkspaceFileSource.UPLOAD,
overwrite=request.overwrite,
)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=str(e))
return UploadFileResponse(
file=_file_to_info(workspace_file),
message="File written successfully",
)
@router.get(
"/files",
summary="List workspace files",
response_model=WorkspaceFileListResponse,
)
async def list_files(
user_id: Annotated[str, fastapi.Security(get_user_id)],
path: Annotated[Optional[str], Query(description="Path prefix filter")] = None,
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> WorkspaceFileListResponse:
"""
List files in the user's workspace.
- **path**: Optional path prefix to filter results
- **limit**: Maximum number of files to return (1-100)
- **offset**: Number of files to skip
"""
workspace = await get_workspace(user_id)
if workspace is None:
return WorkspaceFileListResponse(
files=[],
total_count=0,
path_filter=path,
)
manager = WorkspaceManager(user_id, workspace.id)
files = await manager.list_files(path=path, limit=limit, offset=offset)
total = await manager.get_file_count()
return WorkspaceFileListResponse(
files=[_file_to_info(f) for f in files],
total_count=total,
path_filter=path,
)
@router.get(
"/files/{file_id}",
summary="Get file info by ID",
response_model=WorkspaceFileInfo,
)
async def get_file_info(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> WorkspaceFileInfo:
"""
Get file metadata by file ID.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return _file_to_info(file)
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> Response:
"""
Download a file by its ID.
Returns the file content directly or redirects to a signed URL for GCS.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"',
"Content-Length": str(len(content)),
},
)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"',
"Content-Length": str(len(content)),
},
)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception:
# Fall back to streaming directly from GCS
content = await storage.retrieve(file.storagePath)
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"',
"Content-Length": str(len(content)),
},
)
@router.get(
"/files/{file_id}/url",
summary="Get download URL",
response_model=DownloadUrlResponse,
)
async def get_download_url(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
expires_in: Annotated[int, Query(ge=60, le=86400)] = 3600,
) -> DownloadUrlResponse:
"""
Get a download URL for a file.
- **expires_in**: URL expiration time in seconds (60-86400, default 3600)
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
manager = WorkspaceManager(user_id, workspace.id)
try:
url = await manager.get_download_url(file_id, expires_in)
except FileNotFoundError:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return DownloadUrlResponse(
url=url,
expires_in_seconds=expires_in,
)
@router.delete(
"/files/{file_id}",
summary="Delete file by ID",
response_model=DeleteFileResponse,
)
async def delete_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> DeleteFileResponse:
"""
Delete a file from the workspace (soft-delete).
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
manager = WorkspaceManager(user_id, workspace.id)
success = await manager.delete_file(file_id)
if not success:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return DeleteFileResponse(
success=True,
file_id=file_id,
message="File deleted successfully",
)
# By-path endpoints
@router.get(
"/files/by-path",
summary="Get file info by path",
response_model=WorkspaceFileInfo,
)
async def get_file_by_path(
user_id: Annotated[str, fastapi.Security(get_user_id)],
path: Annotated[str, Query(description="Virtual file path")],
) -> WorkspaceFileInfo:
"""
Get file metadata by virtual path.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file_by_path(workspace.id, path)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return _file_to_info(file)
@router.get(
"/files/by-path/download",
summary="Download file by path",
)
async def download_file_by_path(
user_id: Annotated[str, fastapi.Security(get_user_id)],
path: Annotated[str, Query(description="Virtual file path")],
) -> Response:
"""
Download a file by its virtual path.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file_by_path(workspace.id, path)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"',
"Content-Length": str(len(content)),
},
)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"',
"Content-Length": str(len(content)),
},
)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception:
# Fall back to streaming directly from GCS
content = await storage.retrieve(file.storagePath)
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"',
"Content-Length": str(len(content)),
},
)
@router.delete(
"/files/by-path",
summary="Delete file by path",
response_model=DeleteFileResponse,
)
async def delete_file_by_path(
user_id: Annotated[str, fastapi.Security(get_user_id)],
path: Annotated[str, Query(description="Virtual file path")],
) -> DeleteFileResponse:
"""
Delete a file by its virtual path (soft-delete).
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file_by_path(workspace.id, path)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
manager = WorkspaceManager(user_id, workspace.id)
success = await manager.delete_file(file.id)
return DeleteFileResponse(
success=success,
file_id=file.id,
message="File deleted successfully" if success else "Failed to delete file",
)

View File

@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
@@ -315,6 +316,11 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
workspace_routes.router,
tags=["v2", "workspace"],
prefix="/api/workspace",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -132,8 +133,7 @@ class AIImageCustomizerBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -141,10 +141,9 @@ class AIImageCustomizerBlock(Block):
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
for img in input_data.images
)
@@ -158,7 +157,14 @@ class AIImageCustomizerBlock(Block):
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
yield "image_url", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
except Exception as e:
yield "error", str(e)

View File

@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -13,6 +14,8 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class ImageSize(str, Enum):
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
"https://replicate.delivery/generated-image.webp",
# Test output is a data URI since we now store images
lambda x: x.startswith(""
},
)
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
style_text = style_map.get(style, "")
return f"{style_text} of" if style_text else ""
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
try:
url = await self.generate_image(input_data, credentials)
if url:
yield "image_url", url
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "error", "Image generation returned an empty result."
except Exception as e:

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -21,7 +22,9 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -288,7 +291,12 @@ class AIShortformVideoCreatorBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create a new Webhook.site URL
webhook_token, webhook_url = await self.create_webhook()
@@ -340,7 +348,13 @@ class AIShortformVideoCreatorBlock(Block):
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIAdMakerVideoCreatorBlock(Block):
@@ -463,7 +477,14 @@ class AIAdMakerVideoCreatorBlock(Block):
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -531,7 +552,13 @@ class AIAdMakerVideoCreatorBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIScreenshotToVideoAdBlock(Block):
@@ -642,7 +669,14 @@ class AIScreenshotToVideoAdBlock(Block):
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -710,4 +744,10 @@ class AIScreenshotToVideoAdBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from pydantic import SecretStr
from backend.data.execution import ExecutionContext
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -17,6 +18,8 @@ from backend.sdk import (
Requests,
SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._config import bannerbear
@@ -177,7 +180,12 @@ class BannerbearTextOverlayBlock(Block):
raise Exception(error_msg)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Build the modifications array
modifications = []
@@ -234,6 +242,18 @@ class BannerbearTextOverlayBlock(Block):
# Synchronous request - image should be ready
yield "success", True
yield "image_url", data.get("image_url", "")
# Store the generated image to workspace for persistence
image_url = data.get("image_url", "")
if image_url:
stored_url = await store_media_file(
file=MediaFileType(image_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "image_url", ""
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -9,6 +9,7 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -45,15 +46,20 @@ class FileStoreBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Determine return format based on user preference
# for_block_output: returns workspace:// if available, else data URI
# for_local_processing: returns local file path
return_format = (
"for_block_output" if input_data.base_64 else "for_local_processing"
)
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
file: MediaFileType,
filename: str,
message_content: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> dict:
intents = discord.Intents.default()
intents.guilds = True
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
# Local file path - read from stored media file
# This would be a path from a previous block's output
stored_file = await store_media_file(
graph_exec_id=graph_exec_id,
file=file,
user_id=user_id,
return_content=True, # Get as data URI
execution_context=execution_context,
return_format="for_external_api", # Get content to send to Discord
)
# Now process as data URI
header, encoded = stored_file.split(",", 1)
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
file=input_data.file,
filename=input_data.filename,
message_content=input_data.message_content,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
)
yield "status", result.get("status", "Unknown error")

View File

@@ -17,8 +17,11 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
logger = logging.getLogger(__name__)
@@ -208,11 +211,22 @@ class AIVideoGeneratorBlock(Block):
raise RuntimeError(f"API request failed: {str(e)}")
async def run(
self, input_data: Input, *, credentials: FalCredentials, **kwargs
self,
input_data: Input,
*,
credentials: FalCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
video_url = await self.generate_video(input_data, credentials)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
except Exception as e:
error_message = str(e)
yield "error", error_message

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -134,8 +135,7 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -144,20 +144,25 @@ class AIImageEditorBlock(Block):
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
)
yield "output_image", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output_image", stored_url
async def run_model(
self,

View File

@@ -21,6 +21,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -95,8 +96,7 @@ def _make_mime_text(
async def create_mime_message(
input_data,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
@@ -117,12 +117,13 @@ async def create_mime_message(
if input_data.attachments:
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(
execution_context.graph_exec_id or "", local_path
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -582,27 +583,25 @@ class GmailSendBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._send_email(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", result
async def _send_email(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject or not input_data.body:
raise ValueError(
"At least one recipient, subject, and body are required for sending an email"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
@@ -692,30 +691,28 @@ class GmailCreateDraftBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._create_draft(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
)
async def _create_draft(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject:
raise ValueError(
"At least one recipient and subject are required for creating a draft"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1100,7 +1097,7 @@ class GmailGetThreadBlock(GmailBase):
async def _build_reply_message(
service, input_data, graph_exec_id: str, user_id: str
service, input_data, execution_context: ExecutionContext
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
@@ -1190,12 +1187,11 @@ async def _build_reply_message(
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
abs_path = get_exec_file_path(execution_context.graph_exec_id or "", local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
message = await self._reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", message["id"]
yield "threadId", message.get("threadId", input_data.threadId)
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
yield "email", email
async def _reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Send the message
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
yield "status", "draft_created"
async def _create_draft_reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Create draft with proper thread association
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._forward_message(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", result["id"]
yield "threadId", result.get("threadId", "")
yield "status", "forwarded"
async def _forward_message(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to:
raise ValueError("At least one recipient is required for forwarding")
@@ -1727,12 +1717,13 @@ To: {original_to}
# Add any additional attachments
for attach in input_data.additionalAttachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(
execution_context.graph_exec_id or "", local_path
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
@staticmethod
async def _prepare_files(
graph_exec_id: str,
execution_context: ExecutionContext,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -127,11 +127,15 @@ class SendWebRequestBlock(Block):
(files_name, (filename, BytesIO, mime_type))
"""
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
graph_exec_id = execution_context.graph_exec_id
assert graph_exec_id is not None
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
graph_exec_id, media, user_id, return_content=False
file=media,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -143,7 +147,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -174,7 +178,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
graph_exec_id, input_data.files_name, input_data.files, user_id
execution_context, input_data.files_name, input_data.files
)
# Enforce body format rules
@@ -238,9 +242,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
execution_context: ExecutionContext,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -271,6 +274,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
base_input, execution_context=execution_context, **kwargs
):
yield output_name, output_data

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -462,18 +463,23 @@ class AgentFileInputBlock(AgentInputBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
# Determine return format based on user preference
# for_block_output: returns workspace:// if available, else data URI
# for_local_processing: returns local file path
return_format = (
"for_block_output" if input_data.base_64 else "for_local_processing"
)
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
# 2) Load the clip
if input_data.is_video:
@@ -111,17 +113,19 @@ class LoopVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
@@ -149,12 +153,11 @@ class LoopVideoBlock(Block):
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return as data URI
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out
@@ -200,23 +203,24 @@ class AddAudioToVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
@@ -240,12 +244,11 @@ class AddAudioToVideoBlock(Block):
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return either path or data URI
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
@staticmethod
async def take_screenshot(
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
url: str,
viewport_width: int,
viewport_height: int,
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
return {
"image": await store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_block_output",
)
}
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -7,6 +7,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
@@ -106,14 +107,15 @@ class ReadSpreadsheetBlock(Block):
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
file_path = get_exec_file_path(
execution_context.graph_exec_id or "", stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -10,6 +10,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -17,7 +18,9 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -138,7 +141,12 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create the clip
payload = {
@@ -165,7 +173,14 @@ class CreateTalkingAvatarVideoBlock(Block):
for _ in range(input_data.max_polling_attempts):
status_response = await self.get_clip_status(credentials.api_key, clip_id)
if status_response["status"] == "done":
yield "video_url", status_response["result_url"]
# Store the generated video to the user's workspace for persistence
video_url = status_response["result_url"]
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
return
elif status_response["status"] == "error":
raise RuntimeError(

View File

@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
@@ -233,9 +234,11 @@ class TestStoreMediaFileSecurity:
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
)
@patch("backend.util.file.Path")
@@ -270,9 +273,11 @@ class TestStoreMediaFileSecurity:
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
)

View File

@@ -11,10 +11,22 @@ from backend.blocks.http import (
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.execution import ExecutionContext
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
def make_test_context(
graph_exec_id: str = "test-exec-id",
user_id: str = "test-user-id",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util import json, text
from backend.util.file import get_exec_file_path, store_media_file
@@ -444,18 +445,19 @@ class FileReadBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
# Store the media file properly (handles URLs, data URIs, etc.)
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
file_path = get_exec_file_path(
execution_context.graph_exec_id or "", stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
model_config = {"extra": "ignore"}
# Execution identity
user_id: Optional[str] = None
graph_id: Optional[str] = None
graph_exec_id: Optional[str] = None
graph_version: Optional[int] = None
node_id: Optional[str] = None
node_exec_id: Optional[str] = None
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
# User settings
user_timezone: str = "UTC"
# Execution hierarchy
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# Workspace
workspace_id: Optional[str] = None
session_id: Optional[str] = None
# -------------------------- Models -------------------------- #

View File

@@ -0,0 +1,365 @@
"""
Database CRUD operations for User Workspace.
This module provides functions for managing user workspaces and workspace files.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.enums import WorkspaceFileSource
from prisma.models import UserWorkspace, UserWorkspaceFile
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
"""
Get user's workspace, creating one if it doesn't exist.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance
"""
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
if workspace is None:
workspace = await UserWorkspace.prisma().create(
data={
"userId": user_id,
}
)
logger.info(f"Created new workspace {workspace.id} for user {user_id}")
return workspace
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
"""
Get user's workspace if it exists.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
async def get_workspace_by_id(workspace_id: str) -> Optional[UserWorkspace]:
"""
Get workspace by its ID.
Args:
workspace_id: The workspace ID
Returns:
UserWorkspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"id": workspace_id})
async def create_workspace_file(
workspace_id: str,
name: str,
path: str,
storage_path: str,
mime_type: str,
size_bytes: int,
checksum: Optional[str] = None,
source: WorkspaceFileSource = WorkspaceFileSource.UPLOAD,
source_exec_id: Optional[str] = None,
source_session_id: Optional[str] = None,
metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
"""
Create a new workspace file record.
Args:
workspace_id: The workspace ID
name: User-visible filename
path: Virtual path (e.g., "/documents/report.pdf")
storage_path: Actual storage path (GCS or local)
mime_type: MIME type of the file
size_bytes: File size in bytes
checksum: Optional SHA256 checksum
source: How the file was created
source_exec_id: Graph execution ID if from execution
source_session_id: Chat session ID if from CoPilot
metadata: Optional additional metadata
Returns:
Created UserWorkspaceFile instance
"""
# Normalize path to start with /
if not path.startswith("/"):
path = f"/{path}"
file = await UserWorkspaceFile.prisma().create(
data={
"workspaceId": workspace_id,
"name": name,
"path": path,
"storagePath": storage_path,
"mimeType": mime_type,
"sizeBytes": size_bytes,
"checksum": checksum,
"source": source,
"sourceExecId": source_exec_id,
"sourceSessionId": source_session_id,
"metadata": SafeJson(metadata or {}),
}
)
logger.info(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
)
return file
async def get_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by ID.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
UserWorkspaceFile instance or None
"""
where_clause: dict = {"id": file_id, "isDeleted": False}
if workspace_id:
where_clause["workspaceId"] = workspace_id
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
async def get_workspace_file_by_path(
workspace_id: str,
path: str,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by its virtual path.
Args:
workspace_id: The workspace ID
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
# Normalize path
if not path.startswith("/"):
path = f"/{path}"
return await UserWorkspaceFile.prisma().find_first(
where={
"workspaceId": workspace_id,
"path": path,
"isDeleted": False,
}
)
async def list_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
limit: Optional[int] = None,
offset: int = 0,
) -> list[UserWorkspaceFile]:
"""
List files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/documents/")
include_deleted: Whether to include soft-deleted files
limit: Maximum number of files to return
offset: Number of files to skip
Returns:
List of UserWorkspaceFile instances
"""
where_clause: dict = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().find_many(
where=where_clause,
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def count_workspace_files(
workspace_id: str,
include_deleted: bool = False,
) -> int:
"""
Count files in a workspace.
Args:
workspace_id: The workspace ID
include_deleted: Whether to include soft-deleted files
Returns:
Number of files
"""
where_clause: dict = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
return await UserWorkspaceFile.prisma().count(where=where_clause)
async def soft_delete_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Soft-delete a workspace file.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
Updated UserWorkspaceFile instance or None if not found
"""
# First verify the file exists and belongs to workspace
file = await get_workspace_file(file_id, workspace_id)
if file is None:
return None
updated = await UserWorkspaceFile.prisma().update(
where={"id": file_id},
data={
"isDeleted": True,
"deletedAt": datetime.now(timezone.utc),
},
)
logger.info(f"Soft-deleted workspace file {file_id}")
return updated
async def hard_delete_workspace_file(file_id: str) -> bool:
"""
Permanently delete a workspace file record.
Note: This only deletes the database record. The actual file should be
deleted from storage separately using the storage backend.
Args:
file_id: The file ID
Returns:
True if deleted, False if not found
"""
try:
await UserWorkspaceFile.prisma().delete(where={"id": file_id})
logger.info(f"Hard-deleted workspace file {file_id}")
return True
except Exception:
return False
async def update_workspace_file(
file_id: str,
name: Optional[str] = None,
path: Optional[str] = None,
metadata: Optional[dict] = None,
) -> Optional[UserWorkspaceFile]:
"""
Update workspace file metadata.
Args:
file_id: The file ID
name: New filename
path: New virtual path
metadata: New metadata (merged with existing)
Returns:
Updated UserWorkspaceFile instance or None if not found
"""
update_data: dict = {}
if name is not None:
update_data["name"] = name
if path is not None:
if not path.startswith("/"):
path = f"/{path}"
update_data["path"] = path
if metadata is not None:
# Get existing metadata and merge
file = await get_workspace_file(file_id)
if file is None:
return None
existing_metadata = file.metadata if file.metadata else {}
merged_metadata = {**existing_metadata, **metadata}
update_data["metadata"] = SafeJson(merged_metadata)
if not update_data:
return await get_workspace_file(file_id)
try:
return await UserWorkspaceFile.prisma().update(
where={"id": file_id},
data=update_data,
)
except Exception:
return None
async def workspace_file_exists(
workspace_id: str,
path: str,
) -> bool:
"""
Check if a file exists at the given path in the workspace.
Args:
workspace_id: The workspace ID
path: Virtual path to check
Returns:
True if file exists, False otherwise
"""
file = await get_workspace_file_by_path(workspace_id, path)
return file is not None
async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.sizeBytes for file in files)

View File

@@ -236,7 +236,12 @@ async def execute_node(
input_size = len(input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
# Update execution_context with node-level info
execution_context.node_id = node_id
execution_context.node_exec_id = node_exec_id
# Inject extra execution arguments for the blocks via kwargs
# Keep individual kwargs for backwards compatibility with existing blocks
extra_exec_kwargs: dict = {
"graph_id": graph_id,
"graph_version": graph_version,

View File

@@ -892,11 +892,19 @@ async def add_graph_execution(
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec.id,
graph_version=graph_version,
# Safety settings
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
# User settings
user_timezone=(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
),
# Execution hierarchy
root_execution_id=graph_exec.id,
)

View File

@@ -4,13 +4,29 @@ import re
import shutil
import tempfile
import uuid
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from urllib.parse import urlparse
from prisma.enums import WorkspaceFileSource
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.request import Requests
from backend.util.type import MediaFileType
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
# Return format options for store_media_file
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
MediaReturnFormat = Literal[
"for_local_processing", "for_external_api", "for_block_output"
]
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
@@ -67,36 +83,73 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
async def store_media_file(
graph_exec_id: str,
file: MediaFileType,
user_id: str,
return_content: bool = False,
execution_context: "ExecutionContext",
*,
return_format: MediaReturnFormat | None = None,
# Deprecated parameters - use return_format instead
return_content: bool | None = None,
save_to_workspace: bool | None = None,
) -> MediaFileType:
"""
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
placing or verifying it under:
Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
{tempdir}/exec_file/{exec_id}/...
If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
Otherwise, returns the file media path relative to the exec_id folder.
For each MediaFileType input:
- Data URI: decode and store locally
- URL: download and store locally
- workspace:// reference: read from workspace, store locally
- Local path: verify it exists in exec_file directory
For each MediaFileType type:
- Data URI:
-> decode and store in a new random file in that folder
- URL:
-> download and store in that folder
- Local path:
-> interpret as relative to that folder; verify it exists
(no copying, as it's presumably already there).
We realpath-check so no symlink or '..' can escape the folder.
Return format options:
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
:param graph_exec_id: The unique ID of the graph execution.
:param file: Data URI, URL, or local (relative) path.
:param return_content: If True, return a data URI of the file content.
If False, return the *relative* path inside the exec_id folder.
:return: The requested result: data URI or relative path of the media.
:param file: Data URI, URL, workspace://, or local (relative) path.
:param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
:param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
:param return_content: DEPRECATED. Use return_format instead.
:param save_to_workspace: DEPRECATED. Use return_format instead.
:return: The requested result based on return_format.
"""
# Handle deprecated parameters
if return_format is None:
if return_content is not None or save_to_workspace is not None:
warnings.warn(
"return_content and save_to_workspace are deprecated. "
"Use return_format='for_local_processing', 'for_external_api', or 'for_block_output' instead.",
DeprecationWarning,
stacklevel=2,
)
# Map old parameters to new return_format
if return_content is False or (
return_content is None and save_to_workspace is None
):
# Default or explicit return_content=False -> for_local_processing
return_format = "for_local_processing"
elif save_to_workspace is False:
# return_content=True, save_to_workspace=False -> for_external_api
return_format = "for_external_api"
else:
# return_content=True, save_to_workspace=True (or default) -> for_block_output
return_format = "for_block_output"
# Extract values from execution_context
graph_exec_id = execution_context.graph_exec_id
user_id = execution_context.user_id
if not graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
if not user_id:
raise ValueError("execution_context.user_id is required")
# Create workspace_manager if we have workspace_id (with session scoping)
workspace_manager: WorkspaceManager | None = None
if execution_context.workspace_id:
workspace_manager = WorkspaceManager(
user_id, execution_context.workspace_id, execution_context.session_id
)
# Build base path
base_path = Path(get_exec_file_path(graph_exec_id, ""))
base_path.mkdir(parents=True, exist_ok=True)
@@ -142,9 +195,57 @@ async def store_media_file(
"""
return str(absolute_path.relative_to(base))
# Check if this is a cloud storage path
# Get cloud storage handler for checking cloud paths
cloud_storage = await get_cloud_storage_handler()
if cloud_storage.is_cloud_path(file):
# Track if the input came from workspace (don't re-save it)
is_from_workspace = file.startswith("workspace://")
# Check if this is a workspace file reference
if is_from_workspace:
if workspace_manager is None:
raise ValueError(
"Workspace file reference requires workspace context. "
"This file type is only available in CoPilot sessions."
)
# Parse workspace reference
# workspace://abc123 - by file ID
# workspace:///path/to/file.txt - by virtual path
file_ref = file[12:] # Remove "workspace://"
if file_ref.startswith("/"):
# Path reference
workspace_content = await workspace_manager.read_file(file_ref)
file_info = await workspace_manager.get_file_info_by_path(file_ref)
filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin"
)
else:
# ID reference
workspace_content = await workspace_manager.read_file_by_id(file_ref)
file_info = await workspace_manager.get_file_info(file_ref)
filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin"
)
try:
target_path = _ensure_inside_base(base_path / filename, base_path)
except OSError as e:
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Check file size limit
if len(workspace_content) > MAX_FILE_SIZE:
raise ValueError(
f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE} bytes"
)
# Virus scan the workspace content before writing locally
await scan_content_safe(workspace_content, filename=filename)
target_path.write_bytes(workspace_content)
# Check if this is a cloud storage path
elif cloud_storage.is_cloud_path(file):
# Download from cloud storage and store locally
cloud_content = await cloud_storage.retrieve_file(
file, user_id=user_id, graph_exec_id=graph_exec_id
@@ -230,12 +331,45 @@ async def store_media_file(
if not target_path.is_file():
raise ValueError(f"Local file does not exist: {target_path}")
# Return result
if return_content:
return MediaFileType(_file_to_data_uri(target_path))
else:
# Return based on requested format
if return_format == "for_local_processing":
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
# Returns: relative path in exec_file directory (e.g., "image.png")
return MediaFileType(_strip_base_prefix(target_path, base_path))
elif return_format == "for_external_api":
# Use when sending content to external APIs that need base64
# Returns: data URI (e.g., "...")
return MediaFileType(_file_to_data_uri(target_path))
elif return_format == "for_block_output":
# Use when returning output from a block to user/next block
# Returns: workspace:// ref (CoPilot) or data URI (graph execution)
if workspace_manager is None:
# No workspace available (graph execution without CoPilot)
# Fallback to data URI so the content can still be used/displayed
return MediaFileType(_file_to_data_uri(target_path))
# Don't re-save if input was already from workspace
if is_from_workspace:
# Return original workspace reference
return MediaFileType(file)
# Save new content to workspace
content = target_path.read_bytes()
filename = target_path.name
file_record = await workspace_manager.write_file(
content=content,
filename=filename,
source=WorkspaceFileSource.COPILOT,
overwrite=True,
)
return MediaFileType(f"workspace://{file_record.id}")
else:
raise ValueError(f"Invalid return_format: {return_format}")
def get_dir_size(path: Path) -> int:
"""Get total size of directory."""

View File

@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
def make_test_context(
graph_exec_id: str = "test-exec-123",
user_id: str = "test-user-123",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestFileCloudIntegration:
"""Test cases for cloud storage integration in file utilities."""
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
mock_path_class.side_effect = path_constructor
result = await store_media_file(
graph_exec_id,
MediaFileType(cloud_path),
"test-user-123",
return_content=False,
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify cloud storage operations
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
mock_path_obj.name = "image.png"
with patch("backend.util.file.Path", return_value=mock_path_obj):
result = await store_media_file(
graph_exec_id,
MediaFileType(cloud_path),
"test-user-123",
return_content=True,
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_external_api",
)
# Verify result is a data URI
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
await store_media_file(
graph_exec_id,
MediaFileType(data_uri),
"test-user-123",
return_content=False,
file=MediaFileType(data_uri),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify cloud handler was checked but not used for retrieval
@@ -234,5 +243,6 @@ class TestFileCloudIntegration:
FileNotFoundError, match="File not found in cloud storage"
):
await store_media_file(
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
)

View File

@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The name of the Google Cloud Storage bucket for media files",
)
workspace_storage_dir: str = Field(
default="",
description="Local directory for workspace file storage when GCS is not configured. "
"If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
)
reddit_user_agent: str = Field(
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
description="The user agent for the Reddit API",
@@ -359,8 +365,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=120,
description="The timeout in seconds for Agent Generator service requests",
default=600,
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
)
enable_example_blocks: bool = Field(

View File

@@ -140,14 +140,29 @@ async def execute_block_test(block: Block):
setattr(block, mock_name, mock_obj)
# Populate credentials argument(s)
# Generate IDs for execution context
graph_id = str(uuid.uuid4())
node_id = str(uuid.uuid4())
graph_exec_id = str(uuid.uuid4())
node_exec_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
graph_version = 1 # Default version for tests
extra_exec_kwargs: dict = {
"graph_id": str(uuid.uuid4()),
"node_id": str(uuid.uuid4()),
"graph_exec_id": str(uuid.uuid4()),
"node_exec_id": str(uuid.uuid4()),
"user_id": str(uuid.uuid4()),
"graph_version": 1, # Default version for tests
"execution_context": ExecutionContext(),
"graph_id": graph_id,
"node_id": node_id,
"graph_exec_id": graph_exec_id,
"node_exec_id": node_exec_id,
"user_id": user_id,
"graph_version": graph_version,
"execution_context": ExecutionContext(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
graph_version=graph_version,
node_id=node_id,
node_exec_id=node_exec_id,
),
}
input_model = cast(type[BlockSchema], block.input_schema)

View File

@@ -0,0 +1,370 @@
"""
WorkspaceManager for managing user workspace file operations.
This module provides a high-level interface for workspace file operations,
combining the storage backend and database layer.
"""
import logging
import mimetypes
import uuid
from typing import Optional
from prisma.enums import WorkspaceFileSource
from prisma.models import UserWorkspaceFile
from backend.data.workspace import (
count_workspace_files,
create_workspace_file,
get_workspace_file,
get_workspace_file_by_path,
list_workspace_files,
soft_delete_workspace_file,
workspace_file_exists,
)
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
logger = logging.getLogger(__name__)
# Maximum file size: 100MB per file
MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024
class WorkspaceManager:
"""
Manages workspace file operations.
Combines storage backend operations with database record management.
Supports session-scoped file segmentation where files are stored in
session-specific virtual paths: /sessions/{session_id}/{filename}
"""
def __init__(
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
):
"""
Initialize WorkspaceManager.
Args:
user_id: The user's ID
workspace_id: The workspace ID
session_id: Optional session ID for session-scoped file access
"""
self.user_id = user_id
self.workspace_id = workspace_id
self.session_id = session_id
# Session path prefix for file isolation
self.session_path = f"/sessions/{session_id}" if session_id else ""
def _resolve_path(self, path: str) -> str:
"""
Resolve a path, defaulting to session folder if session_id is set.
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
Args:
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
Returns:
Resolved path with session prefix if applicable
"""
# If path explicitly references a session folder, use it as-is
if path.startswith("/sessions/"):
return path
# If we have a session context, prepend session path
if self.session_path:
# Normalize the path
if not path.startswith("/"):
path = f"/{path}"
return f"{self.session_path}{path}"
# No session context, use path as-is
return path if path.startswith("/") else f"/{path}"
async def read_file(self, path: str) -> bytes:
"""
Read file from workspace by virtual path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path (e.g., "/documents/report.pdf")
Returns:
File content as bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
resolved_path = self._resolve_path(path)
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
if file is None:
raise FileNotFoundError(f"File not found at path: {resolved_path}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
async def read_file_by_id(self, file_id: str) -> bytes:
"""
Read file from workspace by file ID.
Args:
file_id: The file's ID
Returns:
File content as bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
async def write_file(
self,
content: bytes,
filename: str,
path: Optional[str] = None,
mime_type: Optional[str] = None,
source: WorkspaceFileSource = WorkspaceFileSource.UPLOAD,
source_exec_id: Optional[str] = None,
source_session_id: Optional[str] = None,
overwrite: bool = False,
) -> UserWorkspaceFile:
"""
Write file to workspace.
When session_id is set, files are written to /sessions/{session_id}/...
by default. Use explicit /sessions/... paths for cross-session access.
Args:
content: File content as bytes
filename: Filename for the file
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
mime_type: MIME type (auto-detected if not provided)
source: How the file was created
source_exec_id: Graph execution ID if from execution
source_session_id: Chat session ID if from CoPilot
overwrite: Whether to overwrite existing file at path
Returns:
Created UserWorkspaceFile instance
Raises:
ValueError: If file exceeds size limit or path already exists
"""
# Enforce file size limit
if len(content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(content)} bytes exceeds "
f"{MAX_FILE_SIZE_BYTES // (1024*1024)}MB limit"
)
# Determine path with session scoping
if path is None:
path = f"/{filename}"
elif not path.startswith("/"):
path = f"/{path}"
# Resolve path with session prefix
path = self._resolve_path(path)
# Check if file exists at path
existing = await get_workspace_file_by_path(self.workspace_id, path)
if existing is not None:
if overwrite:
# Delete existing file first
await self.delete_file(existing.id)
else:
raise ValueError(f"File already exists at path: {path}")
# Auto-detect MIME type if not provided
if mime_type is None:
mime_type, _ = mimetypes.guess_type(filename)
mime_type = mime_type or "application/octet-stream"
# Compute checksum
checksum = compute_file_checksum(content)
# Generate unique file ID for storage
file_id = str(uuid.uuid4())
# Store file in storage backend
storage = await get_workspace_storage()
storage_path = await storage.store(
workspace_id=self.workspace_id,
file_id=file_id,
filename=filename,
content=content,
)
# Create database record
file = await create_workspace_file(
workspace_id=self.workspace_id,
name=filename,
path=path,
storage_path=storage_path,
mime_type=mime_type,
size_bytes=len(content),
checksum=checksum,
source=source,
source_exec_id=source_exec_id,
source_session_id=source_session_id,
)
logger.info(
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
f"at path {path}, size={len(content)} bytes"
)
return file
async def list_files(
self,
path: Optional[str] = None,
limit: Optional[int] = None,
offset: int = 0,
include_all_sessions: bool = False,
) -> list[UserWorkspaceFile]:
"""
List files in workspace.
When session_id is set and include_all_sessions is False (default),
only files in the current session's folder are listed.
Args:
path: Optional path prefix to filter (e.g., "/documents/")
limit: Maximum number of files to return
offset: Number of files to skip
include_all_sessions: If True, list files from all sessions.
If False (default), only list current session's files.
Returns:
List of UserWorkspaceFile instances
"""
# Determine the effective path prefix
if include_all_sessions:
# Use provided path as-is (or None for all files)
effective_path = path
elif path is not None:
# Resolve the provided path with session scoping
effective_path = self._resolve_path(path)
elif self.session_path:
# Default to session folder
effective_path = self.session_path
else:
# No session context, list all
effective_path = path
return await list_workspace_files(
workspace_id=self.workspace_id,
path_prefix=effective_path,
limit=limit,
offset=offset,
)
async def delete_file(self, file_id: str) -> bool:
"""
Delete a file (soft-delete).
Args:
file_id: The file's ID
Returns:
True if deleted, False if not found
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
return False
# Delete from storage
storage = await get_workspace_storage()
try:
await storage.delete(file.storagePath)
except Exception as e:
logger.warning(f"Failed to delete file from storage: {e}")
# Continue with database soft-delete even if storage delete fails
# Soft-delete database record
result = await soft_delete_workspace_file(file_id, self.workspace_id)
return result is not None
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
"""
Get download URL for a file.
Args:
file_id: The file's ID
expires_in: URL expiration in seconds (default 1 hour)
Returns:
Download URL (signed URL for GCS, API endpoint for local)
Raises:
FileNotFoundError: If file doesn't exist
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.get_download_url(file.storagePath, expires_in)
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
"""
Get file metadata.
Args:
file_id: The file's ID
Returns:
UserWorkspaceFile instance or None
"""
return await get_workspace_file(file_id, self.workspace_id)
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
"""
Get file metadata by path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
resolved_path = self._resolve_path(path)
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
async def file_exists(self, path: str) -> bool:
"""
Check if a file exists at the given path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path
Returns:
True if file exists
"""
resolved_path = self._resolve_path(path)
return await workspace_file_exists(self.workspace_id, resolved_path)
async def get_file_count(self) -> int:
"""
Get number of files in workspace.
Returns:
Number of files
"""
return await count_workspace_files(self.workspace_id)

View File

@@ -0,0 +1,449 @@
"""
Workspace storage backend abstraction for supporting both cloud and local deployments.
This module provides a unified interface for storing workspace files, with implementations
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
"""
import asyncio
import hashlib
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
import aiofiles
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
from backend.util.data import get_data_path
from backend.util.settings import Config
logger = logging.getLogger(__name__)
class WorkspaceStorageBackend(ABC):
"""Abstract interface for workspace file storage."""
@abstractmethod
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""
Store file content, return storage path.
Args:
workspace_id: The workspace ID
file_id: Unique file ID for storage
filename: Original filename
content: File content as bytes
Returns:
Storage path string (cloud path or local path)
"""
pass
@abstractmethod
async def retrieve(self, storage_path: str) -> bytes:
"""
Retrieve file content from storage.
Args:
storage_path: The storage path returned from store()
Returns:
File content as bytes
"""
pass
@abstractmethod
async def delete(self, storage_path: str) -> None:
"""
Delete file from storage.
Args:
storage_path: The storage path to delete
"""
pass
@abstractmethod
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Get URL for downloading the file.
Args:
storage_path: The storage path
expires_in: URL expiration time in seconds (default 1 hour)
Returns:
Download URL (signed URL for GCS, direct API path for local)
"""
pass
@abstractmethod
async def exists(self, storage_path: str) -> bool:
"""
Check if a file exists at the storage path.
Args:
storage_path: The storage path to check
Returns:
True if file exists, False otherwise
"""
pass
class GCSWorkspaceStorage(WorkspaceStorageBackend):
"""Google Cloud Storage implementation for workspace storage."""
def __init__(self, bucket_name: str):
self.bucket_name = bucket_name
self._async_client: Optional[async_gcs_storage.Storage] = None
self._sync_client: Optional[gcs_storage.Client] = None
self._session: Optional[aiohttp.ClientSession] = None
async def _get_async_client(self) -> async_gcs_storage.Storage:
"""Get or create async GCS client."""
if self._async_client is None:
self._session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=100, force_close=False)
)
self._async_client = async_gcs_storage.Storage(session=self._session)
return self._async_client
def _get_sync_client(self) -> gcs_storage.Client:
"""Get or create sync GCS client (for signed URLs)."""
if self._sync_client is None:
self._sync_client = gcs_storage.Client()
return self._sync_client
async def close(self) -> None:
"""Close all client connections."""
if self._async_client is not None:
try:
await self._async_client.close()
except Exception as e:
logger.warning(f"Error closing GCS client: {e}")
self._async_client = None
if self._session is not None:
try:
await self._session.close()
except Exception as e:
logger.warning(f"Error closing session: {e}")
self._session = None
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
"""Build the blob path for workspace files."""
return f"workspaces/{workspace_id}/{file_id}/{filename}"
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""Store file in GCS."""
client = await self._get_async_client()
blob_name = self._build_blob_name(workspace_id, file_id, filename)
# Upload with metadata
upload_time = datetime.now(timezone.utc)
await client.upload(
self.bucket_name,
blob_name,
content,
metadata={
"uploaded_at": upload_time.isoformat(),
"workspace_id": workspace_id,
"file_id": file_id,
},
)
return f"gcs://{self.bucket_name}/{blob_name}"
async def retrieve(self, storage_path: str) -> bytes:
"""Retrieve file from GCS."""
if not storage_path.startswith("gcs://"):
raise ValueError(f"Invalid GCS path: {storage_path}")
# Parse bucket and blob name
path = storage_path[6:] # Remove "gcs://"
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path format: {storage_path}")
bucket_name, blob_name = parts
# Create fresh session for download
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10, force_close=True)
)
try:
client = async_gcs_storage.Storage(session=session)
content = await client.download(bucket_name, blob_name)
await client.close()
return content
except Exception as e:
if "404" in str(e) or "Not Found" in str(e):
raise FileNotFoundError(f"File not found: {storage_path}")
raise
finally:
await session.close()
async def delete(self, storage_path: str) -> None:
"""Delete file from GCS."""
if not storage_path.startswith("gcs://"):
raise ValueError(f"Invalid GCS path: {storage_path}")
path = storage_path[6:]
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path format: {storage_path}")
bucket_name, blob_name = parts
client = await self._get_async_client()
try:
await client.delete(bucket_name, blob_name)
except Exception as e:
if "404" not in str(e) and "Not Found" not in str(e):
raise
# File already deleted, that's fine
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Generate download URL for GCS file.
Attempts to generate a signed URL if running with service account credentials.
Falls back to an API proxy endpoint if signed URL generation fails
(e.g., when running locally with user OAuth credentials).
"""
if not storage_path.startswith("gcs://"):
raise ValueError(f"Invalid GCS path: {storage_path}")
path = storage_path[6:]
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path format: {storage_path}")
bucket_name, blob_name = parts
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
blob_parts = blob_name.split("/")
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
# Try to generate signed URL (requires service account credentials)
try:
sync_client = self._get_sync_client()
bucket = sync_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
url = await asyncio.to_thread(
blob.generate_signed_url,
version="v4",
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
method="GET",
)
return url
except AttributeError as e:
# Signed URL generation requires service account with private key.
# When running with user OAuth credentials, fall back to API proxy.
if "private key" in str(e) and file_id:
logger.debug(
"Cannot generate signed URL (no service account credentials), "
"falling back to API proxy endpoint"
)
return f"/api/workspace/files/{file_id}/download"
raise
async def exists(self, storage_path: str) -> bool:
"""Check if file exists in GCS."""
if not storage_path.startswith("gcs://"):
return False
path = storage_path[6:]
parts = path.split("/", 1)
if len(parts) != 2:
return False
bucket_name, blob_name = parts
try:
client = await self._get_async_client()
await client.download_metadata(bucket_name, blob_name)
return True
except Exception:
return False
class LocalWorkspaceStorage(WorkspaceStorageBackend):
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
def __init__(self, base_dir: Optional[str] = None):
"""
Initialize local storage backend.
Args:
base_dir: Base directory for workspace storage.
If None, defaults to {app_data}/workspaces
"""
if base_dir:
self.base_dir = Path(base_dir)
else:
self.base_dir = Path(get_data_path()) / "workspaces"
# Ensure base directory exists
self.base_dir.mkdir(parents=True, exist_ok=True)
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
"""Build the local file path."""
return self.base_dir / workspace_id / file_id / filename
def _parse_storage_path(self, storage_path: str) -> Path:
"""Parse local storage path to filesystem path."""
if storage_path.startswith("local://"):
relative_path = storage_path[8:] # Remove "local://"
else:
relative_path = storage_path
full_path = (self.base_dir / relative_path).resolve()
# Security check: ensure path is under base_dir
if not str(full_path).startswith(str(self.base_dir.resolve())):
raise ValueError("Invalid storage path: path traversal detected")
return full_path
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""Store file locally."""
file_path = self._build_file_path(workspace_id, file_id, filename)
# Create parent directories
file_path.parent.mkdir(parents=True, exist_ok=True)
# Write file asynchronously
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
# Return relative path as storage path
relative_path = file_path.relative_to(self.base_dir)
return f"local://{relative_path}"
async def retrieve(self, storage_path: str) -> bytes:
"""Retrieve file from local storage."""
file_path = self._parse_storage_path(storage_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {storage_path}")
async with aiofiles.open(file_path, "rb") as f:
return await f.read()
async def delete(self, storage_path: str) -> None:
"""Delete file from local storage."""
file_path = self._parse_storage_path(storage_path)
if file_path.exists():
# Remove file
file_path.unlink()
# Clean up empty parent directories
parent = file_path.parent
while parent != self.base_dir:
try:
if parent.exists() and not any(parent.iterdir()):
parent.rmdir()
else:
break
except OSError:
break
parent = parent.parent
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Get download URL for local file.
For local storage, this returns an API endpoint path.
The actual serving is handled by the API layer.
"""
# Parse the storage path to get the components
if storage_path.startswith("local://"):
relative_path = storage_path[8:]
else:
relative_path = storage_path
# Return the API endpoint for downloading
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
parts = relative_path.split("/")
if len(parts) >= 2:
file_id = parts[1] # Second component is file_id
return f"/api/workspace/files/{file_id}/download"
else:
raise ValueError(f"Invalid storage path format: {storage_path}")
async def exists(self, storage_path: str) -> bool:
"""Check if file exists locally."""
try:
file_path = self._parse_storage_path(storage_path)
return file_path.exists()
except ValueError:
return False
# Global storage backend instance
_workspace_storage: Optional[WorkspaceStorageBackend] = None
_storage_lock = asyncio.Lock()
async def get_workspace_storage() -> WorkspaceStorageBackend:
"""
Get the workspace storage backend instance.
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
"""
global _workspace_storage
if _workspace_storage is None:
async with _storage_lock:
if _workspace_storage is None:
config = Config()
if config.media_gcs_bucket_name:
logger.info(
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
)
_workspace_storage = GCSWorkspaceStorage(
config.media_gcs_bucket_name
)
else:
storage_dir = (
config.workspace_storage_dir
if config.workspace_storage_dir
else None
)
logger.info(
f"Using local workspace storage: {storage_dir or 'default'}"
)
_workspace_storage = LocalWorkspaceStorage(storage_dir)
return _workspace_storage
def compute_file_checksum(content: bytes) -> str:
"""Compute SHA256 checksum of file content."""
return hashlib.sha256(content).hexdigest()

View File

@@ -0,0 +1,52 @@
-- CreateEnum
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
-- CreateTable
CREATE TABLE "UserWorkspace" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "UserWorkspaceFile" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"workspaceId" TEXT NOT NULL,
"name" TEXT NOT NULL,
"path" TEXT NOT NULL,
"storagePath" TEXT NOT NULL,
"mimeType" TEXT NOT NULL,
"sizeBytes" BIGINT NOT NULL,
"checksum" TEXT,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
"deletedAt" TIMESTAMP(3),
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
"sourceExecId" TEXT,
"sourceSessionId" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
-- CreateIndex
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
-- CreateIndex
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
-- AddForeignKey
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -63,6 +63,7 @@ model User {
IntegrationWebhooks IntegrationWebhook[]
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -136,6 +137,66 @@ model CoPilotUnderstanding {
@@index([userId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER WORKSPACE TABLES /////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// User's persistent file storage workspace
model UserWorkspace {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
Files UserWorkspaceFile[]
@@index([userId])
}
// Source of workspace file creation
enum WorkspaceFileSource {
UPLOAD // Direct user upload
EXECUTION // Created by graph execution
COPILOT // Created by CoPilot session
IMPORT // Imported from external source
}
// Individual files in a user's workspace
model UserWorkspaceFile {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
workspaceId String
Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
// File metadata
name String // User-visible filename
path String // Virtual path (e.g., "/documents/report.pdf")
storagePath String // Actual GCS or local storage path
mimeType String
sizeBytes BigInt
checksum String? // SHA256 for integrity
// File state
isDeleted Boolean @default(false)
deletedAt DateTime?
// Source tracking
source WorkspaceFileSource @default(UPLOAD)
sourceExecId String? // graph_exec_id if from execution
sourceSessionId String? // chat_session_id if from CoPilot
metadata Json @default("{}")
@@unique([workspaceId, path])
@@index([workspaceId, isDeleted])
}
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())

View File

@@ -1,12 +1,10 @@
"use client";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
import { useEffect } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
import { LoadingState } from "./components/LoadingState/LoadingState";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";
@@ -20,38 +18,21 @@ export function CopilotShell({ children }: Props) {
isMobile,
isDrawerOpen,
isLoading,
isCreatingSession,
isLoggedIn,
hasActiveSession,
sessions,
currentSessionId,
handleSelectSession,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChat,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage,
fetchNextPage,
isReadyToShowContent,
} = useCopilotShell();
const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
const requestNewChat = useCopilotStore((s) => s.requestNewChat);
useEffect(
function registerNewChatHandler() {
setNewChatHandler(handleNewChat);
return function cleanup() {
setNewChatHandler(null);
};
},
[handleNewChat],
);
function handleNewChatClick() {
requestNewChat();
}
if (!isLoggedIn) {
return (
<div className="flex h-full items-center justify-center">
@@ -72,7 +53,7 @@ export function CopilotShell({ children }: Props) {
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSelectSession}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
hasActiveSession={Boolean(hasActiveSession)}
@@ -82,7 +63,18 @@ export function CopilotShell({ children }: Props) {
<div className="relative flex min-h-0 flex-1 flex-col">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex min-h-0 flex-1 flex-col">
{isReadyToShowContent ? children : <LoadingState />}
{isCreatingSession ? (
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Creating your chat...
</Text>
</div>
</div>
) : (
children
)}
</div>
</div>
@@ -94,7 +86,7 @@ export function CopilotShell({ children }: Props) {
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSelectSession}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
onClose={handleCloseDrawer}

View File

@@ -1,15 +0,0 @@
import { Text } from "@/components/atoms/Text/Text";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
export function LoadingState() {
return (
<div className="flex flex-1 items-center justify-center">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Loading your chats...
</Text>
</div>
</div>
);
}

View File

@@ -3,17 +3,17 @@ import { useState } from "react";
export function useMobileDrawer() {
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
function handleOpenDrawer() {
const handleOpenDrawer = () => {
setIsDrawerOpen(true);
}
};
function handleCloseDrawer() {
const handleCloseDrawer = () => {
setIsDrawerOpen(false);
}
};
function handleDrawerOpenChange(open: boolean) {
const handleDrawerOpenChange = (open: boolean) => {
setIsDrawerOpen(open);
}
};
return {
isDrawerOpen,

View File

@@ -1,11 +1,6 @@
import {
getGetV2ListSessionsQueryKey,
useGetV2ListSessions,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
const PAGE_SIZE = 50;
@@ -16,12 +11,12 @@ export interface UseSessionsPaginationArgs {
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
const [offset, setOffset] = useState(0);
const [accumulatedSessions, setAccumulatedSessions] = useState<
SessionSummaryResponse[]
>([]);
const [totalCount, setTotalCount] = useState<number | null>(null);
const queryClient = useQueryClient();
const onStreamComplete = useChatStore((state) => state.onStreamComplete);
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
{ limit: PAGE_SIZE, offset },
@@ -32,38 +27,23 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
},
);
useEffect(function refreshOnStreamComplete() {
const unsubscribe = onStreamComplete(function handleStreamComplete() {
setOffset(0);
useEffect(() => {
const responseData = okData(data);
if (responseData) {
const newSessions = responseData.sessions;
const total = responseData.total;
setTotalCount(total);
if (offset === 0) {
setAccumulatedSessions(newSessions);
} else {
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
}
} else if (!enabled) {
setAccumulatedSessions([]);
setTotalCount(null);
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
});
return unsubscribe;
}, []);
useEffect(
function updateSessionsFromResponse() {
const responseData = okData(data);
if (responseData) {
const newSessions = responseData.sessions;
const total = responseData.total;
setTotalCount(total);
if (offset === 0) {
setAccumulatedSessions(newSessions);
} else {
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
}
} else if (!enabled) {
setAccumulatedSessions([]);
setTotalCount(null);
}
},
[data, offset, enabled],
);
}
}, [data, offset, enabled]);
const hasNextPage =
totalCount !== null && accumulatedSessions.length < totalCount;
@@ -86,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
}
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
function fetchNextPage() {
const fetchNextPage = () => {
if (hasNextPage && !isFetching) {
setOffset((prev) => prev + PAGE_SIZE);
}
}
};
function reset() {
const reset = () => {
setOffset(0);
setAccumulatedSessions([]);
setTotalCount(null);
}
};
return {
sessions: accumulatedSessions,

View File

@@ -104,76 +104,3 @@ export function mergeCurrentSessionIntoList(
export function getCurrentSessionId(searchParams: URLSearchParams) {
return searchParams.get("sessionId");
}
export function shouldAutoSelectSession(
areAllSessionsLoaded: boolean,
hasAutoSelectedSession: boolean,
paramSessionId: string | null,
visibleSessions: SessionSummaryResponse[],
accumulatedSessions: SessionSummaryResponse[],
isLoading: boolean,
totalCount: number | null,
) {
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
return {
shouldSelect: false,
sessionIdToSelect: null,
shouldCreate: false,
};
}
if (paramSessionId) {
return {
shouldSelect: false,
sessionIdToSelect: null,
shouldCreate: false,
};
}
if (visibleSessions.length > 0) {
return {
shouldSelect: true,
sessionIdToSelect: visibleSessions[0].id,
shouldCreate: false,
};
}
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
}
if (totalCount === 0) {
return {
shouldSelect: false,
sessionIdToSelect: null,
shouldCreate: false,
};
}
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
}
export function checkReadyToShowContent(
areAllSessionsLoaded: boolean,
paramSessionId: string | null,
accumulatedSessions: SessionSummaryResponse[],
isCurrentSessionLoading: boolean,
currentSessionData: SessionDetailResponse | null | undefined,
hasAutoSelectedSession: boolean,
) {
if (!areAllSessionsLoaded) return false;
if (paramSessionId) {
const sessionFound = accumulatedSessions.some(
(s) => s.id === paramSessionId,
);
return (
sessionFound ||
(!isCurrentSessionLoading &&
currentSessionData !== undefined &&
currentSessionData !== null)
);
}
return hasAutoSelectedSession;
}

View File

@@ -1,26 +1,22 @@
"use client";
import {
getGetV2GetSessionQueryKey,
getGetV2ListSessionsQueryKey,
useGetV2GetSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { parseAsString, useQueryState } from "nuqs";
import { usePathname, useSearchParams } from "next/navigation";
import { useEffect, useRef, useState } from "react";
import { useRef } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
checkReadyToShowContent,
convertSessionDetailToSummary,
filterVisibleSessions,
getCurrentSessionId,
mergeCurrentSessionIntoList,
} from "./helpers";
import { getCurrentSessionId } from "./helpers";
import { useShellSessionList } from "./useShellSessionList";
export function useCopilotShell() {
const pathname = usePathname();
@@ -31,7 +27,7 @@ export function useCopilotShell() {
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const isOnHomepage = pathname === "/copilot";
const paramSessionId = searchParams.get("sessionId");
@@ -45,123 +41,80 @@ export function useCopilotShell() {
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
const {
sessions: accumulatedSessions,
isLoading: isSessionsLoading,
isFetching: isSessionsFetching,
hasNextPage,
areAllSessionsLoaded,
fetchNextPage,
reset: resetPagination,
} = useSessionsPagination({
enabled: paginationEnabled,
});
const currentSessionId = getCurrentSessionId(searchParams);
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
useGetV2GetSession(currentSessionId || "", {
const { data: currentSessionData } = useGetV2GetSession(
currentSessionId || "",
{
query: {
enabled: !!currentSessionId,
select: okData,
},
});
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
const hasAutoSelectedRef = useRef(false);
const recentlyCreatedSessionsRef = useRef<
Map<string, SessionSummaryResponse>
>(new Map());
// Mark as auto-selected when sessionId is in URL
useEffect(() => {
if (paramSessionId && !hasAutoSelectedRef.current) {
hasAutoSelectedRef.current = true;
setHasAutoSelectedSession(true);
}
}, [paramSessionId]);
// On homepage without sessionId, mark as ready immediately
useEffect(() => {
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
hasAutoSelectedRef.current = true;
setHasAutoSelectedSession(true);
}
}, [isOnHomepage, paramSessionId]);
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
useEffect(() => {
if (isOnHomepage && !paramSessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
}, [isOnHomepage, paramSessionId, queryClient]);
// Track newly created sessions to ensure they stay visible even when switching away
useEffect(() => {
if (currentSessionId && currentSessionData) {
const isNewSession =
currentSessionData.updated_at === currentSessionData.created_at;
const isNotInAccumulated = !accumulatedSessions.some(
(s) => s.id === currentSessionId,
);
if (isNewSession || isNotInAccumulated) {
const summary = convertSessionDetailToSummary(currentSessionData);
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
}
}
}, [currentSessionId, currentSessionData, accumulatedSessions]);
// Clean up recently created sessions that are now in the accumulated list
useEffect(() => {
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
if (accumulatedSessions.some((s) => s.id === sessionId)) {
recentlyCreatedSessionsRef.current.delete(sessionId);
}
}
}, [accumulatedSessions]);
// Reset pagination when query becomes disabled
const prevPaginationEnabledRef = useRef(paginationEnabled);
useEffect(() => {
if (prevPaginationEnabledRef.current && !paginationEnabled) {
resetPagination();
resetAutoSelect();
}
prevPaginationEnabledRef.current = paginationEnabled;
}, [paginationEnabled, resetPagination]);
const sessions = mergeCurrentSessionIntoList(
accumulatedSessions,
currentSessionId,
currentSessionData,
recentlyCreatedSessionsRef.current,
},
);
const visibleSessions = filterVisibleSessions(sessions);
const {
sessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
} = useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
});
const sidebarSelectedSessionId =
isOnHomepage && !paramSessionId ? null : currentSessionId;
const stopStream = useChatStore((s) => s.stopStream);
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const isStreaming = useCopilotStore((s) => s.isStreaming);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
const isReadyToShowContent = isOnHomepage
? true
: checkReadyToShowContent(
areAllSessionsLoaded,
paramSessionId,
accumulatedSessions,
isCurrentSessionLoading,
currentSessionData,
hasAutoSelectedSession,
);
const pendingActionRef = useRef<(() => void) | null>(null);
function handleSelectSession(sessionId: string) {
async function stopCurrentStream() {
if (!currentSessionId) return;
setIsSwitchingSession(true);
await new Promise<void>((resolve) => {
const unsubscribe = onStreamComplete((completedId) => {
if (completedId === currentSessionId) {
clearTimeout(timeout);
unsubscribe();
resolve();
}
});
const timeout = setTimeout(() => {
unsubscribe();
resolve();
}, 3000);
stopStream(currentSessionId);
});
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
});
setIsSwitchingSession(false);
}
function selectSession(sessionId: string) {
if (sessionId === currentSessionId) return;
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
}
setUrlSessionId(sessionId, { shallow: false });
if (isMobile) handleCloseDrawer();
}
function handleNewChat() {
resetAutoSelect();
function startNewChat() {
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
@@ -170,12 +123,31 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function resetAutoSelect() {
hasAutoSelectedRef.current = false;
setHasAutoSelectedSession(false);
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
selectSession(sessionId);
};
openInterruptModal(pendingActionRef.current);
} else {
selectSession(sessionId);
}
}
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
function handleNewChatClick() {
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
startNewChat();
};
openInterruptModal(pendingActionRef.current);
} else {
startNewChat();
}
}
return {
isMobile,
@@ -183,17 +155,17 @@ export function useCopilotShell() {
isLoggedIn,
hasActiveSession:
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
isLoading,
sessions: visibleSessions,
currentSessionId: sidebarSelectedSessionId,
handleSelectSession,
isLoading: isLoading || isCreatingSession,
isCreatingSession,
sessions,
currentSessionId: urlSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChat,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage: isSessionsFetching,
fetchNextPage,
isReadyToShowContent,
};
}

View File

@@ -0,0 +1,113 @@
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef } from "react";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
convertSessionDetailToSummary,
filterVisibleSessions,
mergeCurrentSessionIntoList,
} from "./helpers";
interface UseShellSessionListArgs {
paginationEnabled: boolean;
currentSessionId: string | null;
currentSessionData: SessionDetailResponse | null | undefined;
isOnHomepage: boolean;
paramSessionId: string | null;
}
export function useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
}: UseShellSessionListArgs) {
const queryClient = useQueryClient();
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const {
sessions: accumulatedSessions,
isLoading: isSessionsLoading,
isFetching: isSessionsFetching,
hasNextPage,
fetchNextPage,
reset: resetPagination,
} = useSessionsPagination({
enabled: paginationEnabled,
});
const recentlyCreatedSessionsRef = useRef<
Map<string, SessionSummaryResponse>
>(new Map());
useEffect(() => {
if (isOnHomepage && !paramSessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
}, [isOnHomepage, paramSessionId, queryClient]);
useEffect(() => {
if (currentSessionId && currentSessionData) {
const isNewSession =
currentSessionData.updated_at === currentSessionData.created_at;
const isNotInAccumulated = !accumulatedSessions.some(
(s) => s.id === currentSessionId,
);
if (isNewSession || isNotInAccumulated) {
const summary = convertSessionDetailToSummary(currentSessionData);
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
}
}
}, [currentSessionId, currentSessionData, accumulatedSessions]);
useEffect(() => {
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
if (accumulatedSessions.some((s) => s.id === sessionId)) {
recentlyCreatedSessionsRef.current.delete(sessionId);
}
}
}, [accumulatedSessions]);
useEffect(() => {
const unsubscribe = onStreamComplete(() => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
});
return unsubscribe;
}, [onStreamComplete, queryClient]);
const sessions = useMemo(
() =>
mergeCurrentSessionIntoList(
accumulatedSessions,
currentSessionId,
currentSessionData,
recentlyCreatedSessionsRef.current,
),
[accumulatedSessions, currentSessionId, currentSessionData],
);
const visibleSessions = useMemo(
() => filterVisibleSessions(sessions),
[sessions],
);
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
return {
sessions: visibleSessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
};
}

View File

@@ -4,51 +4,53 @@ import { create } from "zustand";
interface CopilotStoreState {
isStreaming: boolean;
isNewChatModalOpen: boolean;
newChatHandler: (() => void) | null;
isSwitchingSession: boolean;
isCreatingSession: boolean;
isInterruptModalOpen: boolean;
pendingAction: (() => void) | null;
}
interface CopilotStoreActions {
setIsStreaming: (isStreaming: boolean) => void;
setNewChatHandler: (handler: (() => void) | null) => void;
requestNewChat: () => void;
confirmNewChat: () => void;
cancelNewChat: () => void;
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
setIsCreatingSession: (isCreating: boolean) => void;
openInterruptModal: (onConfirm: () => void) => void;
confirmInterrupt: () => void;
cancelInterrupt: () => void;
}
type CopilotStore = CopilotStoreState & CopilotStoreActions;
export const useCopilotStore = create<CopilotStore>((set, get) => ({
isStreaming: false,
isNewChatModalOpen: false,
newChatHandler: null,
isSwitchingSession: false,
isCreatingSession: false,
isInterruptModalOpen: false,
pendingAction: null,
setIsStreaming(isStreaming) {
set({ isStreaming });
},
setNewChatHandler(handler) {
set({ newChatHandler: handler });
setIsSwitchingSession(isSwitchingSession) {
set({ isSwitchingSession });
},
requestNewChat() {
const { isStreaming, newChatHandler } = get();
if (isStreaming) {
set({ isNewChatModalOpen: true });
} else if (newChatHandler) {
newChatHandler();
}
setIsCreatingSession(isCreatingSession) {
set({ isCreatingSession });
},
confirmNewChat() {
const { newChatHandler } = get();
set({ isNewChatModalOpen: false });
if (newChatHandler) {
newChatHandler();
}
openInterruptModal(onConfirm) {
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
},
cancelNewChat() {
set({ isNewChatModalOpen: false });
confirmInterrupt() {
const { pendingAction } = get();
set({ isInterruptModalOpen: false, pendingAction: null });
if (pendingAction) pendingAction();
},
cancelInterrupt() {
set({ isInterruptModalOpen: false, pendingAction: null });
},
}));

View File

@@ -1,28 +1,5 @@
import type { User } from "@supabase/supabase-js";
export type PageState =
| { type: "welcome" }
| { type: "newChat" }
| { type: "creating"; prompt: string }
| { type: "chat"; sessionId: string; initialPrompt?: string };
export function getInitialPromptFromState(
pageState: PageState,
storedInitialPrompt: string | undefined,
) {
if (storedInitialPrompt) return storedInitialPrompt;
if (pageState.type === "creating") return pageState.prompt;
if (pageState.type === "chat") return pageState.initialPrompt;
}
export function shouldResetToWelcome(pageState: PageState) {
return (
pageState.type !== "newChat" &&
pageState.type !== "creating" &&
pageState.type !== "welcome"
);
}
export function getGreetingName(user?: User | null): string {
if (!user) return "there";
const metadata = user.user_metadata as Record<string, unknown> | undefined;

View File

@@ -1,25 +1,25 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useCopilotStore } from "./copilot-page-store";
import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() {
const { state, handlers } = useCopilotPage();
const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const {
greetingName,
quickActions,
isLoading,
pageState,
isNewChatModalOpen,
hasSession,
initialPrompt,
isReady,
} = state;
const {
@@ -27,20 +27,16 @@ export default function CopilotPage() {
startChatWithPrompt,
handleSessionNotFound,
handleStreamingChange,
handleCancelNewChat,
handleNewChatModalOpen,
} = handlers;
if (!isReady) return null;
if (pageState.type === "chat") {
if (hasSession) {
return (
<div className="flex h-full flex-col">
<Chat
key={pageState.sessionId ?? "welcome"}
className="flex-1"
urlSessionId={pageState.sessionId}
initialPrompt={pageState.initialPrompt}
initialPrompt={initialPrompt}
onSessionNotFound={handleSessionNotFound}
onStreamingChange={handleStreamingChange}
/>
@@ -48,31 +44,33 @@ export default function CopilotPage() {
title="Interrupt current chat?"
styling={{ maxWidth: 300, width: "100%" }}
controlled={{
isOpen: isNewChatModalOpen,
set: handleNewChatModalOpen,
isOpen: isInterruptModalOpen,
set: (open) => {
if (!open) cancelInterrupt();
},
}}
onClose={handleCancelNewChat}
onClose={cancelInterrupt}
>
<Dialog.Content>
<div className="flex flex-col gap-4">
<Text variant="body">
The current chat response will be interrupted. Are you sure you
want to start a new chat?
want to continue?
</Text>
<Dialog.Footer>
<Button
type="button"
variant="outline"
onClick={handleCancelNewChat}
onClick={cancelInterrupt}
>
Cancel
</Button>
<Button
type="button"
variant="primary"
onClick={confirmNewChat}
onClick={confirmInterrupt}
>
Start new chat
Continue
</Button>
</Dialog.Footer>
</div>
@@ -82,19 +80,6 @@ export default function CopilotPage() {
);
}
if (pageState.type === "newChat" || pageState.type === "creating") {
return (
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Loading your chats...
</Text>
</div>
</div>
);
}
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
<div className="w-full text-center">

View File

@@ -10,64 +10,15 @@ import {
type FlagValues,
useGetFlag,
} from "@/services/feature-flags/use-get-flag";
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
import { useFlags } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect, useReducer } from "react";
import { useEffect } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
import { useCopilotURLState } from "./useCopilotURLState";
type CopilotState = {
pageState: PageState;
initialPrompts: Record<string, string>;
previousSessionId: string | null;
};
type CopilotAction =
| { type: "setPageState"; pageState: PageState }
| { type: "setInitialPrompt"; sessionId: string; prompt: string }
| { type: "setPreviousSessionId"; sessionId: string | null };
function isSamePageState(next: PageState, current: PageState) {
if (next.type !== current.type) return false;
if (next.type === "creating" && current.type === "creating") {
return next.prompt === current.prompt;
}
if (next.type === "chat" && current.type === "chat") {
return (
next.sessionId === current.sessionId &&
next.initialPrompt === current.initialPrompt
);
}
return true;
}
function copilotReducer(
state: CopilotState,
action: CopilotAction,
): CopilotState {
if (action.type === "setPageState") {
if (isSamePageState(action.pageState, state.pageState)) return state;
return { ...state, pageState: action.pageState };
}
if (action.type === "setInitialPrompt") {
if (state.initialPrompts[action.sessionId] === action.prompt) return state;
return {
...state,
initialPrompts: {
...state.initialPrompts,
[action.sessionId]: action.prompt,
},
};
}
if (action.type === "setPreviousSessionId") {
if (state.previousSessionId === action.sessionId) return state;
return { ...state, previousSessionId: action.sessionId };
}
return state;
}
import { getGreetingName, getQuickActions } from "./helpers";
import { useCopilotSessionId } from "./useCopilotSessionId";
export function useCopilotPage() {
const router = useRouter();
@@ -75,9 +26,10 @@ export function useCopilotPage() {
const { user, isLoggedIn, isUserLoading } = useSupabase();
const { toast } = useToast();
const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
@@ -88,72 +40,27 @@ export function useCopilotPage() {
const isFlagReady =
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
const [state, dispatch] = useReducer(copilotReducer, {
pageState: { type: "welcome" },
initialPrompts: {},
previousSessionId: null,
});
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
function setPageState(pageState: PageState) {
dispatch({ type: "setPageState", pageState });
}
const hasSession = Boolean(urlSessionId);
const initialPrompt = urlSessionId
? getInitialPrompt(urlSessionId)
: undefined;
function setInitialPrompt(sessionId: string, prompt: string) {
dispatch({ type: "setInitialPrompt", sessionId, prompt });
}
function setPreviousSessionId(sessionId: string | null) {
dispatch({ type: "setPreviousSessionId", sessionId });
}
const { setUrlSessionId } = useCopilotURLState({
pageState: state.pageState,
initialPrompts: state.initialPrompts,
previousSessionId: state.previousSessionId,
setPageState,
setInitialPrompt,
setPreviousSessionId,
});
useEffect(
function transitionNewChatToWelcome() {
if (state.pageState.type === "newChat") {
function setWelcomeState() {
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
}
const timer = setTimeout(setWelcomeState, 300);
return function cleanup() {
clearTimeout(timer);
};
}
},
[state.pageState.type],
);
useEffect(
function ensureAccess() {
if (!isFlagReady) return;
if (isChatEnabled === false) {
router.replace(homepageRoute);
}
},
[homepageRoute, isChatEnabled, isFlagReady, router],
);
useEffect(() => {
if (!isFlagReady) return;
if (isChatEnabled === false) {
router.replace(homepageRoute);
}
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
async function startChatWithPrompt(prompt: string) {
if (!prompt?.trim()) return;
if (state.pageState.type === "creating") return;
if (isCreating) return;
const trimmedPrompt = prompt.trim();
dispatch({
type: "setPageState",
pageState: { type: "creating", prompt: trimmedPrompt },
});
setIsCreating(true);
try {
const sessionResponse = await postV2CreateSession({
@@ -165,27 +72,19 @@ export function useCopilotPage() {
}
const sessionId = sessionResponse.data.id;
dispatch({
type: "setInitialPrompt",
sessionId,
prompt: trimmedPrompt,
});
setInitialPrompt(sessionId, trimmedPrompt);
await queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
await setUrlSessionId(sessionId, { shallow: false });
dispatch({
type: "setPageState",
pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
});
await setUrlSessionId(sessionId, { shallow: true });
} catch (error) {
console.error("[CopilotPage] Failed to start chat:", error);
toast({ title: "Failed to start chat", variant: "destructive" });
Sentry.captureException(error);
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
} finally {
setIsCreating(false);
}
}
@@ -201,21 +100,13 @@ export function useCopilotPage() {
setIsStreaming(isStreamingValue);
}
function handleCancelNewChat() {
cancelNewChat();
}
function handleNewChatModalOpen(isOpen: boolean) {
if (!isOpen) cancelNewChat();
}
return {
state: {
greetingName,
quickActions,
isLoading: isUserLoading,
pageState: state.pageState,
isNewChatModalOpen,
hasSession,
initialPrompt,
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
},
handlers: {
@@ -223,8 +114,32 @@ export function useCopilotPage() {
startChatWithPrompt,
handleSessionNotFound,
handleStreamingChange,
handleCancelNewChat,
handleNewChatModalOpen,
},
};
}
function getInitialPrompt(sessionId: string): string | undefined {
try {
const prompts = JSON.parse(
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
);
return prompts[sessionId];
} catch {
return undefined;
}
}
function setInitialPrompt(sessionId: string, prompt: string): void {
try {
const prompts = JSON.parse(
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
);
prompts[sessionId] = prompt;
sessionStorage.set(
SessionKey.CHAT_INITIAL_PROMPTS,
JSON.stringify(prompts),
);
} catch {
// Ignore storage errors
}
}

View File

@@ -0,0 +1,10 @@
import { parseAsString, useQueryState } from "nuqs";
export function useCopilotSessionId() {
const [urlSessionId, setUrlSessionId] = useQueryState(
"sessionId",
parseAsString,
);
return { urlSessionId, setUrlSessionId };
}

View File

@@ -1,80 +0,0 @@
import { parseAsString, useQueryState } from "nuqs";
import { useLayoutEffect } from "react";
import {
getInitialPromptFromState,
type PageState,
shouldResetToWelcome,
} from "./helpers";
interface UseCopilotUrlStateArgs {
pageState: PageState;
initialPrompts: Record<string, string>;
previousSessionId: string | null;
setPageState: (pageState: PageState) => void;
setInitialPrompt: (sessionId: string, prompt: string) => void;
setPreviousSessionId: (sessionId: string | null) => void;
}
export function useCopilotURLState({
pageState,
initialPrompts,
previousSessionId,
setPageState,
setInitialPrompt,
setPreviousSessionId,
}: UseCopilotUrlStateArgs) {
const [urlSessionId, setUrlSessionId] = useQueryState(
"sessionId",
parseAsString,
);
function syncSessionFromUrl() {
if (urlSessionId) {
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
setPreviousSessionId(urlSessionId);
return;
}
const storedInitialPrompt = initialPrompts[urlSessionId];
const currentInitialPrompt = getInitialPromptFromState(
pageState,
storedInitialPrompt,
);
if (currentInitialPrompt) {
setInitialPrompt(urlSessionId, currentInitialPrompt);
}
setPageState({
type: "chat",
sessionId: urlSessionId,
initialPrompt: currentInitialPrompt,
});
setPreviousSessionId(urlSessionId);
return;
}
const wasInChat = previousSessionId !== null && pageState.type === "chat";
setPreviousSessionId(null);
if (wasInChat) {
setPageState({ type: "newChat" });
return;
}
if (shouldResetToWelcome(pageState)) {
setPageState({ type: "welcome" });
}
}
useLayoutEffect(syncSessionFromUrl, [
urlSessionId,
pageState.type,
previousSessionId,
initialPrompts,
]);
return {
urlSessionId,
setUrlSessionId,
};
}

View File

@@ -1933,7 +1933,9 @@
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/UploadFileResponse" }
"schema": {
"$ref": "#/components/schemas/backend__api__model__UploadFileResponse"
}
}
}
},
@@ -5927,6 +5929,478 @@
}
}
},
"/api/workspace": {
"get": {
"tags": ["v2", "workspace"],
"summary": "Get workspace info",
"description": "Get the current user's workspace information.\nCreates workspace if it doesn't exist.",
"operationId": "getV2Get workspace info",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/WorkspaceInfo" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/workspace/files": {
"get": {
"tags": ["v2", "workspace"],
"summary": "List workspace files",
"description": "List files in the user's workspace.\n\n- **path**: Optional path prefix to filter results\n- **limit**: Maximum number of files to return (1-100)\n- **offset**: Number of files to skip",
"operationId": "getV2List workspace files",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "path",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"description": "Path prefix filter",
"title": "Path"
},
"description": "Path prefix filter"
},
{
"name": "limit",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"maximum": 100,
"minimum": 1,
"default": 50,
"title": "Limit"
}
},
{
"name": "offset",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"minimum": 0,
"default": 0,
"title": "Offset"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/WorkspaceFileListResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"post": {
"tags": ["v2", "workspace"],
"summary": "Upload file to workspace",
"description": "Upload a file to the user's workspace.\n\n- **file**: The file to upload (max 100MB)\n- **path**: Optional virtual path (defaults to \"/{filename}\")\n- **overwrite**: Whether to overwrite existing file at path",
"operationId": "postV2Upload file to workspace",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "path",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Path"
}
},
{
"name": "overwrite",
"in": "query",
"required": false,
"schema": {
"type": "boolean",
"default": false,
"title": "Overwrite"
}
}
],
"requestBody": {
"required": true,
"content": {
"multipart/form-data": {
"schema": {
"$ref": "#/components/schemas/Body_postV2Upload_file_to_workspace"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/backend__api__features__workspace__models__UploadFileResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/workspace/files/by-path": {
"delete": {
"tags": ["v2", "workspace"],
"summary": "Delete file by path",
"description": "Delete a file by its virtual path (soft-delete).",
"operationId": "deleteV2Delete file by path",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "path",
"in": "query",
"required": true,
"schema": {
"type": "string",
"description": "Virtual file path",
"title": "Path"
},
"description": "Virtual file path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/DeleteFileResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"get": {
"tags": ["v2", "workspace"],
"summary": "Get file info by path",
"description": "Get file metadata by virtual path.",
"operationId": "getV2Get file info by path",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "path",
"in": "query",
"required": true,
"schema": {
"type": "string",
"description": "Virtual file path",
"title": "Path"
},
"description": "Virtual file path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/WorkspaceFileInfo" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/workspace/files/by-path/download": {
"get": {
"tags": ["v2", "workspace"],
"summary": "Download file by path",
"description": "Download a file by its virtual path.",
"operationId": "getV2Download file by path",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "path",
"in": "query",
"required": true,
"schema": {
"type": "string",
"description": "Virtual file path",
"title": "Path"
},
"description": "Virtual file path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/workspace/files/write": {
"post": {
"tags": ["v2", "workspace"],
"summary": "Write file content directly",
"description": "Write file content directly to workspace (for programmatic access).\n\n- **filename**: Name for the file\n- **content_base64**: Base64-encoded file content\n- **path**: Optional virtual path (defaults to \"/{filename}\")\n- **mime_type**: Optional MIME type (auto-detected if not provided)\n- **overwrite**: Whether to overwrite existing file at path",
"operationId": "postV2Write file content directly",
"requestBody": {
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/WriteFileRequest" }
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/backend__api__features__workspace__models__UploadFileResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/workspace/files/{file_id}": {
"delete": {
"tags": ["v2", "workspace"],
"summary": "Delete file by ID",
"description": "Delete a file from the workspace (soft-delete).",
"operationId": "deleteV2Delete file by id",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "file_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "File Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/DeleteFileResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"get": {
"tags": ["v2", "workspace"],
"summary": "Get file info by ID",
"description": "Get file metadata by file ID.",
"operationId": "getV2Get file info by id",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "file_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "File Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/WorkspaceFileInfo" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/workspace/files/{file_id}/download": {
"get": {
"tags": ["v2", "workspace"],
"summary": "Download file by ID",
"description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
"operationId": "getV2Download file by id",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "file_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "File Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/workspace/files/{file_id}/url": {
"get": {
"tags": ["v2", "workspace"],
"summary": "Get download URL",
"description": "Get a download URL for a file.\n\n- **expires_in**: URL expiration time in seconds (60-86400, default 3600)",
"operationId": "getV2Get download url",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "file_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "File Id" }
},
{
"name": "expires_in",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"maximum": 86400,
"minimum": 60,
"default": 3600,
"title": "Expires In"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/DownloadUrlResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/health": {
"get": {
"tags": ["health"],
@@ -6698,6 +7172,14 @@
"type": "object",
"title": "Body_postV2Execute a preset"
},
"Body_postV2Upload_file_to_workspace": {
"properties": {
"file": { "type": "string", "format": "binary", "title": "File" }
},
"type": "object",
"required": ["file"],
"title": "Body_postV2Upload file to workspace"
},
"Body_postV2Upload_submission_media": {
"properties": {
"file": { "type": "string", "format": "binary", "title": "File" }
@@ -6989,6 +7471,17 @@
"enum": ["TOP_UP", "USAGE", "GRANT", "REFUND", "CARD_CHECK"],
"title": "CreditTransactionType"
},
"DeleteFileResponse": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"file_id": { "type": "string", "title": "File Id" },
"message": { "type": "string", "title": "Message" }
},
"type": "object",
"required": ["success", "file_id", "message"],
"title": "DeleteFileResponse",
"description": "Response model for file deletion."
},
"DeleteGraphResponse": {
"properties": {
"version_counts": { "type": "integer", "title": "Version Counts" }
@@ -7006,6 +7499,19 @@
"required": ["url", "relevance_score"],
"title": "Document"
},
"DownloadUrlResponse": {
"properties": {
"url": { "type": "string", "title": "Url" },
"expires_in_seconds": {
"type": "integer",
"title": "Expires In Seconds"
}
},
"type": "object",
"required": ["url", "expires_in_seconds"],
"title": "DownloadUrlResponse",
"description": "Response model for download URL."
},
"ExecutionAnalyticsConfig": {
"properties": {
"available_models": {
@@ -11722,24 +12228,6 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UploadFileResponse": {
"properties": {
"file_uri": { "type": "string", "title": "File Uri" },
"file_name": { "type": "string", "title": "File Name" },
"size": { "type": "integer", "title": "Size" },
"content_type": { "type": "string", "title": "Content Type" },
"expires_in_hours": { "type": "integer", "title": "Expires In Hours" }
},
"type": "object",
"required": [
"file_uri",
"file_name",
"size",
"content_type",
"expires_in_hours"
],
"title": "UploadFileResponse"
},
"UserHistoryResponse": {
"properties": {
"history": {
@@ -12043,6 +12531,159 @@
"url"
],
"title": "Webhook"
},
"WorkspaceFileInfo": {
"properties": {
"id": { "type": "string", "title": "Id" },
"name": { "type": "string", "title": "Name" },
"path": { "type": "string", "title": "Path" },
"mime_type": { "type": "string", "title": "Mime Type" },
"size_bytes": { "type": "integer", "title": "Size Bytes" },
"checksum": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Checksum"
},
"source": { "$ref": "#/components/schemas/WorkspaceFileSource" },
"source_exec_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Source Exec Id"
},
"source_session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Source Session Id"
},
"created_at": {
"type": "string",
"format": "date-time",
"title": "Created At"
},
"updated_at": {
"type": "string",
"format": "date-time",
"title": "Updated At"
},
"metadata": {
"additionalProperties": true,
"type": "object",
"title": "Metadata"
}
},
"type": "object",
"required": [
"id",
"name",
"path",
"mime_type",
"size_bytes",
"source",
"created_at",
"updated_at"
],
"title": "WorkspaceFileInfo",
"description": "Response model for workspace file information."
},
"WorkspaceFileListResponse": {
"properties": {
"files": {
"items": { "$ref": "#/components/schemas/WorkspaceFileInfo" },
"type": "array",
"title": "Files"
},
"total_count": { "type": "integer", "title": "Total Count" },
"path_filter": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Path Filter"
}
},
"type": "object",
"required": ["files", "total_count"],
"title": "WorkspaceFileListResponse",
"description": "Response model for listing workspace files."
},
"WorkspaceFileSource": {
"type": "string",
"enum": ["UPLOAD", "EXECUTION", "COPILOT", "IMPORT"],
"title": "WorkspaceFileSource"
},
"WorkspaceInfo": {
"properties": {
"id": { "type": "string", "title": "Id" },
"user_id": { "type": "string", "title": "User Id" },
"created_at": {
"type": "string",
"format": "date-time",
"title": "Created At"
},
"updated_at": {
"type": "string",
"format": "date-time",
"title": "Updated At"
},
"file_count": {
"type": "integer",
"title": "File Count",
"default": 0
}
},
"type": "object",
"required": ["id", "user_id", "created_at", "updated_at"],
"title": "WorkspaceInfo",
"description": "Response model for workspace information."
},
"WriteFileRequest": {
"properties": {
"filename": { "type": "string", "title": "Filename" },
"content_base64": {
"type": "string",
"title": "Content Base64",
"description": "Base64-encoded file content"
},
"path": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Path"
},
"mime_type": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Mime Type"
},
"overwrite": {
"type": "boolean",
"title": "Overwrite",
"default": false
}
},
"type": "object",
"required": ["filename", "content_base64"],
"title": "WriteFileRequest",
"description": "Request model for writing file content directly (for CoPilot tools)."
},
"backend__api__features__workspace__models__UploadFileResponse": {
"properties": {
"file": { "$ref": "#/components/schemas/WorkspaceFileInfo" },
"message": { "type": "string", "title": "Message" }
},
"type": "object",
"required": ["file", "message"],
"title": "UploadFileResponse",
"description": "Response model for file upload."
},
"backend__api__model__UploadFileResponse": {
"properties": {
"file_uri": { "type": "string", "title": "File Uri" },
"file_name": { "type": "string", "title": "File Name" },
"size": { "type": "integer", "title": "Size" },
"content_type": { "type": "string", "title": "Content Type" },
"expires_in_hours": { "type": "integer", "title": "Expires In Hours" }
},
"type": "object",
"required": [
"file_uri",
"file_name",
"size",
"content_type",
"expires_in_hours"
],
"title": "UploadFileResponse"
}
},
"securitySchemes": {

View File

@@ -1,5 +1,6 @@
import {
ApiError,
getServerAuthToken,
makeAuthenticatedFileUpload,
makeAuthenticatedRequest,
} from "@/lib/autogpt-server-api/helpers";
@@ -15,6 +16,69 @@ function buildBackendUrl(path: string[], queryString: string): string {
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
}
/**
* Check if this is a workspace file download request that needs binary response handling.
*/
function isWorkspaceDownloadRequest(path: string[]): boolean {
// Match pattern: api/workspace/files/{id}/download
return (
path.length >= 4 &&
path[0] === "api" &&
path[1] === "workspace" &&
path[2] === "files" &&
path[path.length - 1] === "download"
);
}
/**
* Handle workspace file download requests with proper binary response streaming.
*/
async function handleWorkspaceDownload(
req: NextRequest,
backendUrl: string,
): Promise<NextResponse> {
const token = await getServerAuthToken();
const headers: Record<string, string> = {};
if (token && token !== "no-token-found") {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(backendUrl, {
method: "GET",
headers,
redirect: "follow", // Follow redirects to signed URLs
});
if (!response.ok) {
return NextResponse.json(
{ error: `Failed to download file: ${response.statusText}` },
{ status: response.status },
);
}
// Get the content type from the backend response
const contentType =
response.headers.get("Content-Type") || "application/octet-stream";
const contentDisposition = response.headers.get("Content-Disposition");
// Stream the response body
const responseHeaders: Record<string, string> = {
"Content-Type": contentType,
};
if (contentDisposition) {
responseHeaders["Content-Disposition"] = contentDisposition;
}
// Return the binary content
const arrayBuffer = await response.arrayBuffer();
return new NextResponse(arrayBuffer, {
status: 200,
headers: responseHeaders,
});
}
async function handleJsonRequest(
req: NextRequest,
method: string,
@@ -180,6 +244,11 @@ async function handler(
};
try {
// Handle workspace file downloads separately (binary response)
if (method === "GET" && isWorkspaceDownloadRequest(path)) {
return await handleWorkspaceDownload(req, backendUrl);
}
if (method === "GET" || method === "DELETE") {
responseBody = await handleGetDeleteRequest(method, backendUrl, req);
} else if (contentType?.includes("application/json")) {

View File

@@ -1,16 +1,17 @@
"use client";
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { useEffect, useRef } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
import { ChatLoader } from "./components/ChatLoader/ChatLoader";
import { useChat } from "./useChat";
export interface ChatProps {
className?: string;
urlSessionId?: string | null;
initialPrompt?: string;
onSessionNotFound?: () => void;
onStreamingChange?: (isStreaming: boolean) => void;
@@ -18,12 +19,13 @@ export interface ChatProps {
export function Chat({
className,
urlSessionId,
initialPrompt,
onSessionNotFound,
onStreamingChange,
}: ChatProps) {
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
const {
messages,
isLoading,
@@ -35,41 +37,49 @@ export function Chat({
showLoader,
} = useChat({ urlSessionId });
useEffect(
function handleMissingSession() {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
if (!isSessionNotFound || isLoading || isCreating) return;
if (hasHandledNotFoundRef.current) return;
hasHandledNotFoundRef.current = true;
onSessionNotFound();
},
[onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating],
);
useEffect(() => {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
if (!isSessionNotFound || isLoading || isCreating) return;
if (hasHandledNotFoundRef.current) return;
hasHandledNotFoundRef.current = true;
onSessionNotFound();
}, [
onSessionNotFound,
urlSessionId,
isSessionNotFound,
isLoading,
isCreating,
]);
const shouldShowLoader =
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
return (
<div className={cn("flex h-full flex-col", className)}>
{/* Main Content */}
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
{/* Loading State */}
{showLoader && (isLoading || isCreating) && (
{shouldShowLoader && (
<div className="flex flex-1 items-center justify-center">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<div className="flex flex-col items-center gap-3">
<LoadingSpinner size="large" className="text-neutral-400" />
<Text variant="body" className="text-zinc-500">
Loading your chats...
{isSwitchingSession
? "Switching chat..."
: "Loading your chat..."}
</Text>
</div>
</div>
)}
{/* Error State */}
{error && !isLoading && (
{error && !isLoading && !isSwitchingSession && (
<ChatErrorState error={error} onRetry={createSession} />
)}
{/* Session Content */}
{sessionId && !isLoading && !error && (
{sessionId && !isLoading && !error && !isSwitchingSession && (
<ChatContainer
sessionId={sessionId}
initialMessages={messages}

View File

@@ -58,39 +58,17 @@ function notifyStreamComplete(
}
}
function cleanupCompletedStreams(completedStreams: Map<string, StreamResult>) {
function cleanupExpiredStreams(
completedStreams: Map<string, StreamResult>,
): Map<string, StreamResult> {
const now = Date.now();
for (const [sessionId, result] of completedStreams) {
const cleaned = new Map(completedStreams);
for (const [sessionId, result] of cleaned) {
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
completedStreams.delete(sessionId);
cleaned.delete(sessionId);
}
}
}
function moveToCompleted(
activeStreams: Map<string, ActiveStream>,
completedStreams: Map<string, StreamResult>,
streamCompleteCallbacks: Set<StreamCompleteCallback>,
sessionId: string,
) {
const stream = activeStreams.get(sessionId);
if (!stream) return;
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
completedStreams.set(sessionId, result);
activeStreams.delete(sessionId);
cleanupCompletedStreams(completedStreams);
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(streamCompleteCallbacks, sessionId);
}
return cleaned;
}
export const useChatStore = create<ChatStore>((set, get) => ({
@@ -106,17 +84,31 @@ export const useChatStore = create<ChatStore>((set, get) => ({
context,
onChunk,
) {
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
const existingStream = activeStreams.get(sessionId);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
moveToCompleted(
activeStreams,
completedStreams,
streamCompleteCallbacks,
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
);
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
const abortController = new AbortController();
@@ -132,36 +124,76 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunkCallbacks: initialCallbacks,
};
activeStreams.set(sessionId, stream);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
try {
await executeStream(stream, message, isUserMessage, context);
} finally {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
moveToCompleted(
activeStreams,
completedStreams,
streamCompleteCallbacks,
sessionId,
);
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(
currentState.streamCompleteCallbacks,
sessionId,
);
}
}
}
}
},
stopStream: function stopStream(sessionId) {
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
const stream = activeStreams.get(sessionId);
if (stream) {
stream.abortController.abort();
stream.status = "completed";
moveToCompleted(
activeStreams,
completedStreams,
streamCompleteCallbacks,
sessionId,
);
}
const state = get();
const stream = state.activeStreams.get(sessionId);
if (!stream) return;
stream.abortController.abort();
stream.status = "completed";
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
notifyStreamComplete(state.streamCompleteCallbacks, sessionId);
},
subscribeToStream: function subscribeToStream(
@@ -169,16 +201,18 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunk,
skipReplay = false,
) {
const { activeStreams } = get();
const state = get();
const stream = state.activeStreams.get(sessionId);
const stream = activeStreams.get(sessionId);
if (stream) {
if (!skipReplay) {
for (const chunk of stream.chunks) {
onChunk(chunk);
}
}
stream.onChunkCallbacks.add(onChunk);
return function unsubscribe() {
stream.onChunkCallbacks.delete(onChunk);
};
@@ -204,7 +238,12 @@ export const useChatStore = create<ChatStore>((set, get) => ({
},
clearCompletedStream: function clearCompletedStream(sessionId) {
get().completedStreams.delete(sessionId);
const state = get();
if (!state.completedStreams.has(sessionId)) return;
const newCompletedStreams = new Map(state.completedStreams);
newCompletedStreams.delete(sessionId);
set({ completedStreams: newCompletedStreams });
},
isStreaming: function isStreaming(sessionId) {
@@ -213,11 +252,21 @@ export const useChatStore = create<ChatStore>((set, get) => ({
},
registerActiveSession: function registerActiveSession(sessionId) {
get().activeSessions.add(sessionId);
const state = get();
if (state.activeSessions.has(sessionId)) return;
const newActiveSessions = new Set(state.activeSessions);
newActiveSessions.add(sessionId);
set({ activeSessions: newActiveSessions });
},
unregisterActiveSession: function unregisterActiveSession(sessionId) {
get().activeSessions.delete(sessionId);
const state = get();
if (!state.activeSessions.has(sessionId)) return;
const newActiveSessions = new Set(state.activeSessions);
newActiveSessions.delete(sessionId);
set({ activeSessions: newActiveSessions });
},
isSessionActive: function isSessionActive(sessionId) {
@@ -225,10 +274,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
},
onStreamComplete: function onStreamComplete(callback) {
const { streamCompleteCallbacks } = get();
streamCompleteCallbacks.add(callback);
const state = get();
const newCallbacks = new Set(state.streamCompleteCallbacks);
newCallbacks.add(callback);
set({ streamCompleteCallbacks: newCallbacks });
return function unsubscribe() {
streamCompleteCallbacks.delete(callback);
const currentState = get();
const cleanedCallbacks = new Set(currentState.streamCompleteCallbacks);
cleanedCallbacks.delete(callback);
set({ streamCompleteCallbacks: cleanedCallbacks });
};
},
}));

View File

@@ -48,6 +48,15 @@ export function handleTextEnded(
const completedText = deps.streamingChunksRef.current.join("");
if (completedText.trim()) {
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
msg.role === "assistant" &&
msg.content === completedText,
);
if (exists) return prev;
const assistantMessage: ChatMessageData = {
type: "message",
role: "assistant",
@@ -203,13 +212,24 @@ export function handleStreamEnd(
]);
}
if (completedContent.trim()) {
const assistantMessage: ChatMessageData = {
type: "message",
role: "assistant",
content: completedContent,
timestamp: new Date(),
};
deps.setMessages((prev) => [...prev, assistantMessage]);
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
msg.role === "assistant" &&
msg.content === completedContent,
);
if (exists) return prev;
const assistantMessage: ChatMessageData = {
type: "message",
role: "assistant",
content: completedContent,
timestamp: new Date(),
};
return [...prev, assistantMessage];
});
}
deps.setStreamingChunks([]);
deps.streamingChunksRef.current = [];

View File

@@ -304,6 +304,7 @@ export function parseToolResponse(
if (isAgentArray(agentsData)) {
return {
type: "agent_carousel",
toolId,
toolName: "agent_carousel",
agents: agentsData,
totalCount: parsedResult.total_count as number | undefined,
@@ -316,6 +317,7 @@ export function parseToolResponse(
if (responseType === "execution_started") {
return {
type: "execution_started",
toolId,
toolName: "execution_started",
executionId: (parsedResult.execution_id as string) || "",
agentName: (parsedResult.graph_name as string) || undefined,
@@ -341,6 +343,41 @@ export function parseToolResponse(
timestamp: timestamp || new Date(),
};
}
if (responseType === "operation_started") {
return {
type: "operation_started",
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
message:
(parsedResult.message as string) ||
"Operation started. You can close this tab.",
timestamp: timestamp || new Date(),
};
}
if (responseType === "operation_pending") {
return {
type: "operation_pending",
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
message:
(parsedResult.message as string) ||
"Operation in progress. Please wait...",
timestamp: timestamp || new Date(),
};
}
if (responseType === "operation_in_progress") {
return {
type: "operation_in_progress",
toolName: (parsedResult.tool_name as string) || toolName,
toolCallId: (parsedResult.tool_call_id as string) || toolId,
message:
(parsedResult.message as string) ||
"Operation already in progress. Please wait...",
timestamp: timestamp || new Date(),
};
}
if (responseType === "need_login") {
return {
type: "login_needed",

View File

@@ -82,11 +82,110 @@ export function useChatContainer({
[sessionId, stopStreaming, activeStreams, subscribeToStream],
);
const allMessages = useMemo(
() => [...processInitialMessages(initialMessages), ...messages],
[initialMessages, messages],
// Collect toolIds from completed tool results in initialMessages
// Used to filter out operation messages when their results arrive
const completedToolIds = useMemo(() => {
const processedInitial = processInitialMessages(initialMessages);
const ids = new Set<string>();
for (const msg of processedInitial) {
if (
msg.type === "tool_response" ||
msg.type === "agent_carousel" ||
msg.type === "execution_started"
) {
const toolId = (msg as any).toolId;
if (toolId) {
ids.add(toolId);
}
}
}
return ids;
}, [initialMessages]);
// Clean up local operation messages when their completed results arrive from polling
// This effect runs when completedToolIds changes (i.e., when polling brings new results)
useEffect(
function cleanupCompletedOperations() {
if (completedToolIds.size === 0) return;
setMessages((prev) => {
const filtered = prev.filter((msg) => {
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false; // Remove - operation completed
}
}
return true;
});
// Only update state if something was actually filtered
return filtered.length === prev.length ? prev : filtered;
});
},
[completedToolIds],
);
// Combine initial messages from backend with local streaming messages,
// then deduplicate to prevent duplicates when polling refreshes initialMessages
const allMessages = useMemo(() => {
const processedInitial = processInitialMessages(initialMessages);
// Filter local messages to remove operation messages for completed tools
const filteredLocalMessages = messages.filter((msg) => {
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false; // Filter out - operation completed
}
}
return true;
});
const combined = [...processedInitial, ...filteredLocalMessages];
// Deduplicate by content+role+timestamp. When initialMessages is refreshed via polling,
// it may contain messages that are also in the local `messages` state.
// Including timestamp prevents dropping legitimate repeated messages (e.g., user sends "yes" twice)
const seen = new Set<string>();
return combined.filter((msg) => {
// Create a key based on type, role, content, and timestamp for deduplication
let key: string;
if (msg.type === "message") {
// Use timestamp (rounded to nearest second) to allow slight variations
// while still catching true duplicates from SSE/polling overlap
const ts = msg.timestamp
? Math.floor(new Date(msg.timestamp).getTime() / 1000)
: "";
key = `msg:${msg.role}:${ts}:${msg.content}`;
} else if (msg.type === "tool_call") {
key = `toolcall:${msg.toolId}`;
} else if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
// Dedupe operation messages by toolId or operationId
key = `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
} else {
// For other types, use a combination of type and first few fields
key = `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
}
if (seen.has(key)) {
return false;
}
seen.add(key);
return true;
});
}, [initialMessages, messages]);
async function sendMessage(
content: string,
isUserMessage: boolean = true,

View File

@@ -16,6 +16,7 @@ import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
import { PendingOperationWidget } from "../PendingOperationWidget/PendingOperationWidget";
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
@@ -71,6 +72,9 @@ export function ChatMessage({
isLoginNeeded,
isCredentialsNeeded,
isClarificationNeeded,
isOperationStarted,
isOperationPending,
isOperationInProgress,
} = useChatMessage(message);
const displayContent = getDisplayContent(message, isUser);
@@ -126,10 +130,6 @@ export function ChatMessage({
[displayContent, message],
);
function isLongResponse(content: string): boolean {
return content.split("\n").length > 5;
}
const handleTryAgain = useCallback(() => {
if (message.type !== "message" || !onSendMessage) return;
onSendMessage(message.content, message.role === "user");
@@ -294,6 +294,42 @@ export function ChatMessage({
);
}
// Render operation_started messages (long-running background operations)
if (isOperationStarted && message.type === "operation_started") {
return (
<PendingOperationWidget
status="started"
message={message.message}
toolName={message.toolName}
className={className}
/>
);
}
// Render operation_pending messages (operations in progress when refreshing)
if (isOperationPending && message.type === "operation_pending") {
return (
<PendingOperationWidget
status="pending"
message={message.message}
toolName={message.toolName}
className={className}
/>
);
}
// Render operation_in_progress messages (duplicate request while operation running)
if (isOperationInProgress && message.type === "operation_in_progress") {
return (
<PendingOperationWidget
status="in_progress"
message={message.message}
toolName={message.toolName}
className={className}
/>
);
}
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
if (isToolResponse && message.type === "tool_response") {
return (
@@ -358,7 +394,7 @@ export function ChatMessage({
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
</Button>
)}
{!isUser && isFinalMessage && isLongResponse(displayContent) && (
{!isUser && isFinalMessage && !isStreaming && (
<Button
variant="ghost"
size="icon"

View File

@@ -61,6 +61,7 @@ export type ChatMessageData =
}
| {
type: "agent_carousel";
toolId: string;
toolName: string;
agents: Array<{
id: string;
@@ -74,6 +75,7 @@ export type ChatMessageData =
}
| {
type: "execution_started";
toolId: string;
toolName: string;
executionId: string;
agentName?: string;
@@ -103,6 +105,29 @@ export type ChatMessageData =
message: string;
sessionId: string;
timestamp?: string | Date;
}
| {
type: "operation_started";
toolName: string;
toolId: string;
operationId: string;
message: string;
timestamp?: string | Date;
}
| {
type: "operation_pending";
toolName: string;
toolId: string;
operationId: string;
message: string;
timestamp?: string | Date;
}
| {
type: "operation_in_progress";
toolName: string;
toolCallId: string;
message: string;
timestamp?: string | Date;
};
export function useChatMessage(message: ChatMessageData) {
@@ -124,5 +149,8 @@ export function useChatMessage(message: ChatMessageData) {
isExecutionStarted: message.type === "execution_started",
isInputsNeeded: message.type === "inputs_needed",
isClarificationNeeded: message.type === "clarification_needed",
isOperationStarted: message.type === "operation_started",
isOperationPending: message.type === "operation_pending",
isOperationInProgress: message.type === "operation_in_progress",
};
}

View File

@@ -30,6 +30,7 @@ export function ClarificationQuestionsWidget({
className,
}: Props) {
const [answers, setAnswers] = useState<Record<string, string>>({});
const [isSubmitted, setIsSubmitted] = useState(false);
function handleAnswerChange(keyword: string, value: string) {
setAnswers((prev) => ({ ...prev, [keyword]: value }));
@@ -41,11 +42,42 @@ export function ClarificationQuestionsWidget({
if (!allAnswered) {
return;
}
setIsSubmitted(true);
onSubmitAnswers(answers);
}
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
// Show submitted state after answers are submitted
if (isSubmitted) {
return (
<div
className={cn(
"group relative flex w-full justify-start gap-3 px-4 py-3",
className,
)}
>
<div className="flex w-full max-w-3xl gap-3">
<div className="flex-shrink-0">
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-green-500">
<CheckCircleIcon className="h-4 w-4 text-white" weight="bold" />
</div>
</div>
<div className="flex min-w-0 flex-1 flex-col">
<Card className="p-4">
<Text variant="h4" className="mb-1 text-slate-900">
Answers submitted
</Text>
<Text variant="small" className="text-slate-600">
Processing your responses...
</Text>
</Card>
</div>
</div>
</div>
);
}
return (
<div
className={cn(

View File

@@ -1,6 +1,8 @@
"use client";
import { getGetV2DownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import { cn } from "@/lib/utils";
import { EyeSlash } from "@phosphor-icons/react";
import React from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
@@ -29,12 +31,90 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
type?: string;
}
/**
* Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend.
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
*
* Uses the generated API URL helper and routes through the Next.js proxy
* which handles authentication and proper backend routing.
*/
function resolveWorkspaceUrl(src: string): string {
if (src.startsWith("workspace://")) {
const fileId = src.replace("workspace://", "");
// Use the generated API URL helper to get the correct path
const apiPath = getGetV2DownloadFileByIdUrl(fileId);
// Route through the Next.js proxy (same pattern as customMutator for client-side)
return `/api/proxy${apiPath}`;
}
return src;
}
/**
* URL transformer for ReactMarkdown.
* Transforms workspace:// URLs to backend API download URLs before rendering.
* This is needed because ReactMarkdown sanitizes URLs and only allows
* http, https, mailto, and tel protocols by default.
*/
function transformUrl(url: string): string {
return resolveWorkspaceUrl(url);
}
/**
* Check if the image URL is a workspace file (AI cannot see these yet).
* After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/...
*/
function isWorkspaceImage(src: string | undefined): boolean {
return src?.includes("/workspace/files/") ?? false;
}
/**
* Custom image component that shows an indicator when the AI cannot see the image.
* Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
*/
function MarkdownImage(props: Record<string, unknown>) {
const src = props.src as string | undefined;
const alt = props.alt as string | undefined;
const aiCannotSee = isWorkspaceImage(src);
// If no src, show a placeholder
if (!src) {
return (
<span className="my-2 inline-block rounded border border-amber-200 bg-amber-50 px-2 py-1 text-sm text-amber-700">
[Image: {alt || "missing src"}]
</span>
);
}
return (
<span className="relative my-2 inline-block">
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={src}
alt={alt || "Image"}
className="h-auto max-w-full rounded-md border border-zinc-200"
loading="lazy"
/>
{aiCannotSee && (
<span
className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
title="The AI cannot see this image"
>
<EyeSlash size={14} />
<span>AI cannot see this image</span>
</span>
)}
</span>
);
}
export function MarkdownContent({ content, className }: MarkdownContentProps) {
return (
<div className={cn("markdown-content", className)}>
<ReactMarkdown
skipHtml={true}
remarkPlugins={[remarkGfm]}
urlTransform={transformUrl}
components={{
code: ({ children, className, ...props }: CodeProps) => {
const isInline = !className?.includes("language-");
@@ -206,6 +286,9 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
{children}
</td>
),
img: ({ src, alt, ...props }) => (
<MarkdownImage src={src} alt={alt} {...props} />
),
}}
>
{content}

View File

@@ -0,0 +1,109 @@
"use client";
import { Card } from "@/components/atoms/Card/Card";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { CircleNotch, CheckCircle, XCircle } from "@phosphor-icons/react";
type OperationStatus =
| "pending"
| "started"
| "in_progress"
| "completed"
| "error";
interface Props {
status: OperationStatus;
message: string;
toolName?: string;
className?: string;
}
function getOperationTitle(toolName?: string): string {
if (!toolName) return "Operation";
// Convert tool name to human-readable format
// e.g., "create_agent" -> "Creating Agent", "edit_agent" -> "Editing Agent"
if (toolName === "create_agent") return "Creating Agent";
if (toolName === "edit_agent") return "Editing Agent";
// Default: capitalize and format tool name
return toolName
.split("_")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
.join(" ");
}
export function PendingOperationWidget({
status,
message,
toolName,
className,
}: Props) {
const isPending =
status === "pending" || status === "started" || status === "in_progress";
const isCompleted = status === "completed";
const isError = status === "error";
const operationTitle = getOperationTitle(toolName);
return (
<div
className={cn(
"group relative flex w-full justify-start gap-3 px-4 py-3",
className,
)}
>
<div className="flex w-full max-w-3xl gap-3">
<div className="flex-shrink-0">
<div
className={cn(
"flex h-7 w-7 items-center justify-center rounded-lg",
isPending && "bg-blue-500",
isCompleted && "bg-green-500",
isError && "bg-red-500",
)}
>
{isPending && (
<CircleNotch
className="h-4 w-4 animate-spin text-white"
weight="bold"
/>
)}
{isCompleted && (
<CheckCircle className="h-4 w-4 text-white" weight="bold" />
)}
{isError && (
<XCircle className="h-4 w-4 text-white" weight="bold" />
)}
</div>
</div>
<div className="flex min-w-0 flex-1 flex-col">
<Card className="space-y-2 p-4">
<div>
<Text variant="h4" className="mb-1 text-slate-900">
{isPending && operationTitle}
{isCompleted && `${operationTitle} Complete`}
{isError && `${operationTitle} Failed`}
</Text>
<Text variant="small" className="text-slate-600">
{message}
</Text>
</div>
{isPending && (
<Text variant="small" className="italic text-slate-500">
Check your library in a few minutes.
</Text>
)}
{toolName && (
<Text variant="small" className="text-slate-400">
Tool: {toolName}
</Text>
)}
</Card>
</div>
</div>
</div>
);
}

View File

@@ -37,6 +37,77 @@ export function getErrorMessage(result: unknown): string {
return "An error occurred";
}
/**
* Check if a value is a workspace file reference.
*/
function isWorkspaceRef(value: unknown): value is string {
return typeof value === "string" && value.startsWith("workspace://");
}
/**
* Check if a workspace reference appears to be an image based on common patterns.
* Since workspace refs don't have extensions, we check the context or assume image
* for certain block types.
*/
function isLikelyImageRef(value: string, outputKey?: string): boolean {
if (!isWorkspaceRef(value)) return false;
// Check output key name for image-related hints
const imageKeywords = [
"image",
"img",
"photo",
"picture",
"thumbnail",
"avatar",
"icon",
"screenshot",
"output",
"result",
"generated",
];
if (outputKey) {
const lowerKey = outputKey.toLowerCase();
if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
return true;
}
}
// Default to treating workspace refs as potential images
// since that's the most common case for generated content
return true;
}
/**
* Format a single output value, converting workspace refs to markdown images.
*/
function formatOutputValue(value: unknown, outputKey?: string): string {
if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
// Format as markdown image
return `![${outputKey || "Generated image"}](${value})`;
}
if (typeof value === "string") {
// Check for data URIs (images)
if (value.startsWith("data:image/")) {
return `![${outputKey || "Generated image"}](${value})`;
}
return value;
}
if (Array.isArray(value)) {
return value
.map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`))
.join("\n\n");
}
if (typeof value === "object" && value !== null) {
return JSON.stringify(value, null, 2);
}
return String(value);
}
function getToolCompletionPhrase(toolName: string): string {
const toolCompletionPhrases: Record<string, string> = {
add_understanding: "Updated your business information",
@@ -127,10 +198,26 @@ export function formatToolResponse(result: unknown, toolName: string): string {
case "block_output":
const blockName = (response.block_name as string) || "Block";
const outputs = response.outputs as Record<string, unknown> | undefined;
const outputs = response.outputs as Record<string, unknown[]> | undefined;
if (outputs && Object.keys(outputs).length > 0) {
const outputKeys = Object.keys(outputs);
return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`;
const formattedOutputs: string[] = [];
for (const [key, values] of Object.entries(outputs)) {
if (!Array.isArray(values) || values.length === 0) continue;
// Format each value in the output array
for (const value of values) {
const formatted = formatOutputValue(value, key);
if (formatted) {
formattedOutputs.push(formatted);
}
}
}
if (formattedOutputs.length > 0) {
return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`;
}
return `${blockName} executed successfully.`;
}
return `${blockName} executed successfully.`;

View File

@@ -10,7 +10,7 @@ export function UserChatBubble({ children, className }: UserChatBubbleProps) {
return (
<div
className={cn(
"group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-right text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
"group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-left text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
className,
)}
style={{

View File

@@ -59,6 +59,7 @@ export function useChatSession({
query: {
enabled: !!sessionId,
select: okData,
staleTime: 0,
retry: shouldRetrySessionLoad,
retryDelay: getSessionRetryDelay,
},
@@ -102,15 +103,98 @@ export function useChatSession({
}
}, [createError, loadError]);
// Check if there are any pending operations in the messages
// Must check all operation types: operation_pending, operation_started, operation_in_progress
const hasPendingOperations = useMemo(() => {
if (!messages || messages.length === 0) return false;
const pendingTypes = new Set([
"operation_pending",
"operation_in_progress",
"operation_started",
]);
return messages.some((msg) => {
if (msg.role !== "tool" || !msg.content) return false;
try {
const content =
typeof msg.content === "string"
? JSON.parse(msg.content)
: msg.content;
return pendingTypes.has(content?.type);
} catch {
return false;
}
});
}, [messages]);
// Refresh sessions list when a pending operation completes
// (hasPendingOperations transitions from true to false)
const prevHasPendingOperationsRef = useRef(hasPendingOperations);
useEffect(
function refreshSessionsListOnLoad() {
if (sessionId && sessionData && !isLoadingSession) {
function refreshSessionsListOnOperationComplete() {
const wasHasPending = prevHasPendingOperationsRef.current;
prevHasPendingOperationsRef.current = hasPendingOperations;
// Only invalidate when transitioning from pending to not pending
if (wasHasPending && !hasPendingOperations && sessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
},
[sessionId, sessionData, isLoadingSession, queryClient],
[hasPendingOperations, sessionId, queryClient],
);
// Poll for updates when there are pending operations (long poll - 10s intervals with backoff)
const pollAttemptRef = useRef(0);
const hasPendingOperationsRef = useRef(hasPendingOperations);
hasPendingOperationsRef.current = hasPendingOperations;
useEffect(
function pollForPendingOperations() {
if (!sessionId || !hasPendingOperations) {
pollAttemptRef.current = 0;
return;
}
let cancelled = false;
let timeoutId: ReturnType<typeof setTimeout> | null = null;
// Calculate delay with exponential backoff: 10s, 15s, 20s, 25s, 30s (max)
const baseDelay = 10000;
const maxDelay = 30000;
function schedule() {
const delay = Math.min(
baseDelay + pollAttemptRef.current * 5000,
maxDelay,
);
timeoutId = setTimeout(async () => {
if (cancelled) return;
console.info(
`[useChatSession] Polling for pending operation updates (attempt ${pollAttemptRef.current + 1})`,
);
pollAttemptRef.current += 1;
try {
await refetch();
} catch (err) {
console.error("[useChatSession] Poll failed:", err);
} finally {
// Continue polling if still pending and not cancelled
if (!cancelled && hasPendingOperationsRef.current) {
schedule();
}
}
}, delay);
}
schedule();
return () => {
cancelled = true;
if (timeoutId) clearTimeout(timeoutId);
};
},
[sessionId, hasPendingOperations, refetch],
);
async function createSession() {
@@ -239,6 +323,7 @@ export function useChatSession({
isCreating,
error,
isSessionNotFound: isNotFoundError(loadError),
hasPendingOperations,
createSession,
loadSession,
refreshSession,

View File

@@ -3,6 +3,7 @@ import { environment } from "../environment";
export enum SessionKey {
CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts",
CHAT_INITIAL_PROMPTS = "chat_initial_prompts",
}
function get(key: SessionKey) {

View File

@@ -1 +0,0 @@
# Video editing blocks

View File

@@ -277,6 +277,50 @@ async def run(
token = credentials.api_key.get_secret_value()
```
### Handling Files
When your block works with files (images, videos, documents), use `store_media_file()`:
```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
):
# PROCESSING: Need local file path for tools like ffmpeg, MoviePy, PIL
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# EXTERNAL API: Need base64 content for APIs like Replicate, OpenAI
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# OUTPUT: Return to user/next block (auto-adapts to context)
result = await store_media_file(
file=generated_url,
execution_context=execution_context,
return_format="for_block_output", # workspace:// in CoPilot, data URI in graphs
)
yield "image_url", result
```
**Return format options:**
- `"for_local_processing"` - Local file path for processing tools
- `"for_external_api"` - Data URI for external APIs needing base64
- `"for_block_output"` - **Always use for outputs** - automatically picks best format
## Testing Your Block
```bash

View File

@@ -111,6 +111,71 @@ Follow these steps to create and test a new block:
- `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run"
- `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed
- `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed.
- `execution_context`: An `ExecutionContext` object containing user_id, graph_exec_id, workspace_id, and session_id. Required for file handling.
### Handling Files in Blocks
When your block needs to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. This function handles downloading, validation, virus scanning, and storage.
**Import:**
```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
```
**The `return_format` parameter determines what you get back:**
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# PROCESSING: Need to work with file locally (ffmpeg, MoviePy, PIL)
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path, ffmpeg, subprocess, etc.
full_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
# EXTERNAL API: Need to send content to an API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "..." - send to external API
# OUTPUT: Returning result from block to user/next block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123" (persistent, context-efficient)
# In graphs: result_url = "data:image/png;base64,..." (for next block/display)
```
**Key points:**
- `for_block_output` is the **only** format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never manually check for `workspace_id` - let `for_block_output` handle the logic
- The function handles URLs, data URIs, `workspace://` references, and local paths as input
### Field Types