Compare commits

...

18 Commits

Author SHA1 Message Date
Swifty
06758adefd Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-01-29 13:33:32 +01:00
Swifty
c01c29a059 fmt issues 2026-01-29 13:28:01 +01:00
Ubbe
b94c83aacc feat(frontend): Copilot speech to text via Whisper model (#11871)
## Changes 🏗️


https://github.com/user-attachments/assets/d9c12ac0-625c-4b38-8834-e494b5eda9c0

Add a "speech to text" feature in the Chat input fox of Copilot, similar
as what you have in ChatGPT.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and try the speech to text feature as part of the chat
input box

### For configuration changes:

We need to add `OPENAI_API_KEY=` to Vercel (used in the front-end), both
in Dev and Prod.

- [x] `.env.default` is updated or already compatible with my changes

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 17:46:36 +07:00
Swifty
d738059da8 added long running task support 2026-01-29 10:24:14 +01:00
Nicholas Tindle
7668c17d9c feat(platform): add User Workspace for persistent CoPilot file storage (#11867)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).

Fixes SECRT-1833

### Changes 🏗️

**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local
filesystem for self-hosted)
- Add `workspace_id` and `session_id` fields to `ExecutionContext`

**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE
/api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()` - returns
`workspace://file-id` references
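
For illustration, a minimal sketch of how a client might call the new
endpoints listed above. The base URL, auth header, and the `id` response
field are assumptions, not confirmed API details:

```python
# Hypothetical client for the workspace file endpoints above; the host,
# auth scheme, and the "id" response field are assumptions.
import requests

BASE = "http://localhost:8006/api/workspace"
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credentials

# List files in the user's workspace
files = requests.get(f"{BASE}/files", headers=HEADERS).json()

# Upload a file
with open("report.pdf", "rb") as f:
    created = requests.post(f"{BASE}/files", headers=HEADERS, files={"file": f}).json()

# Download by ID (a `workspace://file-id` reference carries this ID)
file_id = created["id"]  # assumed response field
content = requests.get(f"{BASE}/files/{file_id}/download", headers=HEADERS).content

# Delete it again
requests.delete(f"{BASE}/files/{file_id}", headers=HEADERS)
```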

**Block Updates:**
- Refactor all file-handling blocks to use unified `ExecutionContext`
parameter
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)

**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator

**CoPilot Context Mapping:**
- Session = Agent (graph_id) = Run (graph_exec_id)
- Files scoped to `/sessions/{session_id}/`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
  - [ ] Verify image returns `workspace://file-id` (not base64)
  - [ ] Verify image renders in chat with visibility indicator
  - [ ] Verify workspace files persist across sessions
  - [ ] Test list/read/write workspace files via CoPilot tools
  - [ ] Test local storage backend for self-hosted deployments

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables,
storage backends, download API, and chat tools) and rewires
`store_media_file()`/block execution context across many blocks, so
regressions could impact file handling, access control, or storage
costs.
> 
> **Overview**
> Adds a **persistent per-user Workspace** (new
`UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` +
`WorkspaceStorageBackend` with GCS/local implementations) and wires it
into the API via a new `/api/workspace/files/{file_id}/download` route
(including header-sanitized `Content-Disposition`) and shutdown
lifecycle hooks.
> 
> Extends `ExecutionContext` to carry execution identity +
`workspace_id`/`session_id`, updates executor tooling to clone
node-specific contexts, and updates `run_block` (CoPilot) to create a
session-scoped workspace and synthetic graph/run/node IDs.
> 
> Refactors `store_media_file()` to require `execution_context` +
`return_format` and to support `workspace://` references; migrates many
media/file-handling blocks and related tests to the new API and to
persist generated media as `workspace://...` (or fall back to data URIs
outside CoPilot), and adds CoPilot chat tools for
listing/reading/writing/deleting workspace files with safeguards against
context bloat.
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-01-29 05:49:47 +00:00
Nicholas Tindle
e0dfae5732 fix(platform): evaluate chat flag after auth for correct redirect (#11873)
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 14:58:02 -06:00
Zamil Majdy
7df867d645 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-01-28 12:29:41 -06:00
Zamil Majdy
d855f79874 fix(platform): reduce Sentry alert spam for expected errors (#11872)
## Summary
- Add `InvalidInputError` for validation errors (search term too long,
invalid pagination) - returns 400 instead of 500
- Remove redundant try/catch blocks in library routes - global exception
handlers already handle `ValueError`→400 and `NotFoundError`→404
- Aggregate embedding backfill errors and log once at the end instead of
per content type to prevent Sentry issue spam
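
As a rough sketch of the handler pattern described above (registration
style and error-class names follow the description; the actual code may
differ):

```python
# Sketch: global exception handlers map expected errors to 4xx responses,
# so they stop surfacing as 500s (and Sentry alerts). Names are illustrative.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class InvalidInputError(ValueError):
    """Raised for user-supplied input that fails validation."""


class NotFoundError(Exception):
    """Raised when a requested resource does not exist."""


app = FastAPI()


@app.exception_handler(ValueError)  # also matches InvalidInputError subclasses
async def value_error_handler(request: Request, exc: ValueError) -> JSONResponse:
    return JSONResponse(status_code=400, content={"detail": str(exc)})


@app.exception_handler(NotFoundError)
async def not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
    return JSONResponse(status_code=404, content={"detail": str(exc)})
```

With handlers like these registered globally, route-level try/except blocks
that only re-map the same errors become redundant.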

## Test plan
- [x] Verify validation errors (search term >100 chars) return 400 Bad
Request
- [x] Verify NotFoundError still returns 404
- [x] Verify embedding errors are logged once at the end with aggregated
counts

Fixes AUTOGPT-SERVER-7K5, BUILDER-6NC

---------

Co-authored-by: Swifty <craigswift13@gmail.com>
2026-01-29 01:28:27 +07:00
Swifty
dac99694fe Merge branch 'release/v0.6.44' 2026-01-28 12:19:13 +01:00
Nicholas Tindle
0953983944 feat(platform): disable onboarding redirects and add $5 signup bonus (#11862)
Disable automatic onboarding redirects on signup/login while keeping the
checklist/wallet functional. Users now receive $5 (500 credits) on their
first visit to /copilot.

### Changes 🏗️

- **Frontend**: `shouldShowOnboarding()` now returns `false`, disabling
auto-redirects to `/onboarding`
- **Backend**: Added `VISIT_COPILOT` onboarding step with 500 credit
($5) reward
- **Frontend**: Copilot page automatically completes `VISIT_COPILOT`
step on mount
- **Database**: Migration to add `VISIT_COPILOT` to `OnboardingStep`
enum

NOTE: `/onboarding/1-welcome` now redirects to `/library`, as
`shouldShowOnboarding()` is always false.

Users land directly on `/copilot` after signup/login and receive $5
invisibly (not shown in checklist UI).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] New user signup (email/password) → lands on `/copilot`, wallet
shows 500 credits
- [x] Verified credits are only granted once (idempotent via onboarding
reward mechanism)
- [x] Existing user login (already granted flag set) → lands on
`/copilot`, no duplicate credits
  - [x] Checklist/wallet remains functional

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

No configuration changes required.

---

OPEN-2967

🤖 Generated with [Claude Code](https://claude.ai/code)


<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces a new onboarding step and adjusts onboarding flow.
> 
> - Adds `VISIT_COPILOT` onboarding step (+500 credits) with DB enum
migration and API/type updates
> - Copilot page auto-completes `VISIT_COPILOT` on mount to grant the
welcome bonus
> - Changes `/onboarding/enabled` to require user context and return
`false` when `CHAT` feature is enabled (skips legacy onboarding)
> - Wallet now refreshes credits on any onboarding `step_completed`
notification; confetti limited to visible tasks
> - Test flows updated to accept redirects to `copilot`/`library` and
verify authenticated state
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
2026-01-28 07:22:46 +00:00
Zamil Majdy
0058cd3ba6 fix(frontend): auto-poll for long-running tool completion (#11866)
## Summary
Fixes the issue where the "Creating Agent" spinner doesn't auto-update
when agent generation completes; previously the user had to refresh the browser.

**Changes:**
- **Frontend polling**: Add `onOperationStarted` callback to trigger
polling when `operation_started` is received via SSE
- **Polling backoff**: 2s, 4s, 6s, 8s... up to 30s max
- **Message deduplication**: Use content-based keys (role + content)
instead of timestamps to prevent duplicate messages
- **Message ordering**: Preserve server message order instead of
timestamp-based sorting
- **Debug cleanup**: Remove verbose console.log/console.info statements
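
The backoff schedule is linear with a cap. The real logic lives in the
TypeScript frontend; a Python rendering of the same schedule, purely for
illustration:

```python
# Illustrative version of the polling backoff above: 2s, 4s, 6s, ... capped at 30s.
import time
from typing import Callable


def poll_until_done(check_done: Callable[[], bool], base: float = 2.0, cap: float = 30.0) -> None:
    """Poll check_done() with linearly increasing delays until it returns True."""
    attempt = 1
    while not check_done():
        time.sleep(min(base * attempt, cap))
        attempt += 1


# Example: pretend the operation finishes on the third poll
calls = {"count": 0}

def check_done() -> bool:
    calls["count"] += 1
    return calls["count"] >= 3

poll_until_done(check_done, base=0.01)  # small base so the demo runs quickly
```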

## Test plan
- [ ] Start agent generation in copilot
- [ ] Verify "Creating Agent" spinner appears
- [ ] Wait for completion (2-5 min) WITHOUT refreshing
- [ ] Verify agent carousel appears automatically when done
- [ ] Verify no duplicate messages in chat
- [ ] Verify message order is correct (user → assistant → tool_call →
tool_response)
2026-01-28 10:03:21 +07:00
Nicholas Tindle
ea035224bc feat(copilot): Increase max_agent_runs and max_agent_schedules (#11865)
Config change to increase the maximum number of times an agent can run in
a chat and the maximum number of schedules Copilot can create in one chat.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Increases per-chat operational limits for Copilot.
> 
> - Bumps `max_agent_runs` default from `3` to `30` in `ChatConfig`
> - Bumps `max_agent_schedules` default from `3` to `30` in `ChatConfig`
<!-- /CURSOR_SUMMARY -->
2026-01-28 01:08:02 +00:00
Nicholas Tindle
62813a1ea6 Delete backend/blocks/video/__init__.py (#11864)
oops file
### Changes 🏗️

Removes a file that should not have been committed.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Removes erroneous `backend/blocks/video/__init__.py`, eliminating an
unintended `video` package.
> 
> - Deletes a placeholder comment-only file
<!-- /CURSOR_SUMMARY -->
2026-01-28 00:58:49 +00:00
Bently
67405f7eb9 fix(copilot): ensure tool_call/tool_response pairs stay intact during context compaction (#11863)
## Summary

Fixes context compaction breaking tool_call/tool_response pairs, causing
API validation errors.

## Problem

When context compaction slices messages with `messages[-KEEP_RECENT:]`,
a naive slice can separate an assistant message containing `tool_calls`
from its corresponding tool response messages. This causes API
validation errors like:

```
messages.0.content.1: unexpected 'tool_use_id' found in 'tool_result' blocks: orphan_12345.
Each 'tool_result' block must have a corresponding 'tool_use' block in the previous message.
```

## Solution

Added `_ensure_tool_pairs_intact()` helper function that:
1. Detects orphan tool responses in a slice (tool messages whose
`tool_call_id` has no matching assistant message)
2. Extends the slice backwards to include the missing assistant messages
3. Falls back to removing orphan tool responses if the assistant cannot
be found (edge case)
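
A hedged sketch of that safeguard (the real helper lives in the chat
service and its exact signature may differ; messages are plain dicts here,
and the kept slice is assumed to be a suffix of the full message list):

```python
# Sketch of _ensure_tool_pairs_intact(); names and message shape are assumptions.
def _ensure_tool_pairs_intact(all_messages: list[dict], kept: list[dict]) -> list[dict]:
    """Extend (or prune) a suffix slice so every tool response has its tool_call."""
    call_ids = {
        call["id"]
        for msg in kept
        if msg.get("role") == "assistant"
        for call in msg.get("tool_calls") or []
    }
    orphans = {
        msg["tool_call_id"]
        for msg in kept
        if msg.get("role") == "tool" and msg["tool_call_id"] not in call_ids
    }
    if not orphans:
        return kept

    # Walk backwards to find the assistant messages that issued the orphan calls.
    start = len(all_messages) - len(kept)
    for i in range(start - 1, -1, -1):
        msg = all_messages[i]
        calls = {call["id"] for call in msg.get("tool_calls") or []}
        if msg.get("role") == "assistant" and calls & orphans:
            start = i  # extend the slice to include this assistant message
            orphans -= calls
            if not orphans:
                break
    kept = all_messages[start:]

    # Fallback: drop tool responses whose assistant message was never found.
    remaining = {
        call["id"]
        for msg in kept
        if msg.get("role") == "assistant"
        for call in msg.get("tool_calls") or []
    }
    return [
        msg for msg in kept
        if msg.get("role") != "tool" or msg["tool_call_id"] in remaining
    ]
```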

Applied this safeguard to:
- The initial `KEEP_RECENT` slice (line ~990)
- The progressive fallback slices when still over token limit (line
~1079)

## Testing

- Syntax validated with `python -m py_compile`
- Logic reviewed for correctness

## Linear

Fixes SECRT-1839

---
*Debugged by Toran & Orion in #agpt Discord*
2026-01-28 00:21:54 +00:00
Zamil Majdy
171ff6e776 feat(backend): persist long-running tool results to survive SSE disconnects (#11856)
## Summary

Agent generation (`create_agent`, `edit_agent`) can take 1-5 minutes.
Previously, if the user closed their browser tab during this time:
1. The SSE connection would die
2. The tool execution would be cancelled via `CancelledError`
3. The result would be lost - even if the agent-generator service
completed successfully

This PR ensures long-running tool operations survive SSE disconnections.

### Changes 🏗️

**Backend:**
- **base.py**: Added `is_long_running` property to `BaseTool` for tools
to opt-in to background execution
- **create_agent.py / edit_agent.py**: Set `is_long_running = True`
- **models.py**: Added `OperationStartedResponse`,
`OperationPendingResponse`, `OperationInProgressResponse` types
- **service.py**: Modified `_yield_tool_call()` to:
  - Check if tool is `is_long_running`
  - Save "pending" message to chat history immediately
  - Spawn background task that runs independently of SSE
  - Return `operation_started` immediately (don't wait)
  - Update chat history with result when background task completes
- Track running operations for idempotency (prevents duplicate ops on
refresh)
- **db.py**: Added `update_tool_message_content()` to update pending
messages
- **model.py**: Added `invalidate_session_cache()` to clear Redis after
background completion

**Frontend:**
- **useChatMessage.ts**: Added operation message types
- **helpers.ts**: Handle `operation_started`, `operation_pending`,
`operation_in_progress` response types
- **PendingOperationWidget**: New component to display operation status
with spinner
- **ChatMessage.tsx**: Render `PendingOperationWidget` for operation
messages

### How It Works

```
User Request → Save "pending" message → Spawn background task → Return immediately
                                              ↓
                                     Task runs independently of SSE
                                              ↓
                                     On completion: Update message in chat history
                                              ↓
                                     User refreshes → Loads history → Sees result
```
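
A condensed sketch of this flow (function names follow the lists above;
the persistence helpers are stand-ins for the real db/model functions):

```python
# Minimal sketch of running a long-running tool in the background so it
# survives SSE disconnects. Helper bodies are stand-ins, not the real code.
import asyncio
import uuid


async def save_pending_message(session_id: str, tool_call_id: str, operation_id: str) -> None: ...
async def update_tool_message_content(session_id: str, tool_call_id: str, new_content: str) -> None: ...
async def invalidate_session_cache(session_id: str) -> None: ...


async def yield_tool_call(tool, session_id: str, tool_call_id: str, params: dict):
    """Sketch of the long-running branch described above."""
    if not tool.is_long_running:
        return await tool.execute(**params)  # normal, in-stream execution

    operation_id = str(uuid.uuid4())
    # 1. Save a "pending" message immediately so chat history survives a refresh.
    await save_pending_message(session_id, tool_call_id, operation_id)

    async def run_in_background() -> None:
        # 2. Runs independently of the SSE connection's lifetime.
        result = await tool.execute(**params)
        await update_tool_message_content(session_id, tool_call_id, result)
        await invalidate_session_cache(session_id)

    # The real code also tracks the task for idempotency (no duplicate ops on refresh).
    asyncio.create_task(run_in_background())

    # 3. Return immediately instead of awaiting the tool.
    return {"type": "operation_started", "operation_id": operation_id}
```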

### User Experience

1. User requests agent creation
2. Sees "Agent creation started. You can close this tab - check your
library in a few minutes."
3. Can close browser tab safely
4. When they return, chat shows the completed result (or error)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] pyright passes (0 errors)
  - [x] TypeScript checks pass
  - [x] Formatters applied

### Test Plan

1. Start agent creation in copilot
2. Close browser tab immediately after seeing "operation_started" 
3. Wait 2-3 minutes
4. Reopen chat
5. Verify: Chat history shows completion message and agent appears in
library

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
2026-01-28 05:09:34 +07:00
Lluis Agusti
349b1f9c79 hotfix(frontend): copilot session handling refinements... 2026-01-28 02:53:45 +07:00
Lluis Agusti
277b0537e9 hotfix(frontend): copilot simplification... 2026-01-28 02:10:18 +07:00
Ubbe
071b3bb5cd fix(frontend): more copilot refinements (#11858)
## Changes 🏗️

On the **Copilot** page:

- prevent unnecessary sidebar repaints
- show a disclaimer when switching chats in the sidebar that doing so
terminates the current stream
- handle loading states better
- persist streams more reliably when disconnecting


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app locally and test the above
2026-01-28 00:49:28 +07:00
118 changed files with 7187 additions and 1467 deletions

View File

@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
- Format Python code with `poetry run format`.
- Format frontend code using `pnpm format`.
## Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any`, if not types available use `unknown`
## Testing
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
Types:
- feat
- fix
- refactor
- ci
- dx (developer experience)
Scopes:
- platform
- platform/library
- platform/marketplace
- backend
- backend/executor
- frontend
- frontend/library
- frontend/marketplace
- blocks
Types: - feat - fix - refactor - ci - dx (developer experience)
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
## Pull requests

View File

@@ -85,17 +85,6 @@ pnpm format
pnpm types
```
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
### Backend Architecture
@@ -194,6 +183,50 @@ ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
**Handling files in blocks with `store_media_file()`:**
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
@@ -217,14 +250,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Avoid comments at all times unless the code is very complex
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any`, if not types available use `unknown`
### Security Implementation

View File

@@ -0,0 +1,308 @@
"""RabbitMQ consumer for operation completion messages.
This module provides a consumer that listens for completion notifications
from external services (like Agent Generator) and triggers the appropriate
stream registry and chat service updates.
"""
import asyncio
import logging
import orjson
from pydantic import BaseModel
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
)
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Queue and exchange configuration
OPERATION_COMPLETE_EXCHANGE = Exchange(
name="chat_operations",
type=ExchangeType.DIRECT,
durable=True,
)
OPERATION_COMPLETE_QUEUE = Queue(
name="chat_operation_complete",
durable=True,
exchange=OPERATION_COMPLETE_EXCHANGE,
routing_key="operation.complete",
)
RABBITMQ_CONFIG = RabbitMQConfig(
exchanges=[OPERATION_COMPLETE_EXCHANGE],
queues=[OPERATION_COMPLETE_QUEUE],
)
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from RabbitMQ."""
def __init__(self):
self._rabbitmq: AsyncRabbitMQ | None = None
self._consumer_task: asyncio.Task | None = None
self._running = False
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
await self._rabbitmq.connect()
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info("Chat completion consumer started")
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._rabbitmq:
await self._rabbitmq.disconnect()
self._rabbitmq = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop."""
if not self._rabbitmq:
logger.error("RabbitMQ not initialized")
return
try:
channel = await self._rabbitmq.get_channel()
queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
async with queue.iterator() as queue_iter:
async for message in queue_iter:
if not self._running:
break
try:
async with message.process():
await self._handle_message(message.body)
except Exception as e:
logger.error(
f"Error processing completion message: {e}",
exc_info=True,
)
# Message will be requeued due to exception
except asyncio.CancelledError:
logger.info("Consumer cancelled")
except Exception as e:
logger.error(f"Consumer error: {e}", exc_info=True)
# Attempt to reconnect after a delay
if self._running:
await asyncio.sleep(5)
await self._consume_messages()
async def _handle_message(self, body: bytes) -> None:
"""Handle a single completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
# Try to look up by task_id directly
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
# Publish result to stream registry
result_output = message.result if message.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database
result_str = (
message.result
if isinstance(message.result, str)
else (
orjson.dumps(message.result).decode("utf-8")
if message.result
else '{"status": "completed"}'
)
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed
await stream_registry.mark_task_completed(task.task_id, status="completed")
logger.info(
f"Successfully processed completion for task {task.task_id} "
f"(operation {message.operation_id})"
)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
error_msg = message.error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
error_response = ErrorResponse(
message=error_msg,
error=message.error,
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed
await stream_registry.mark_task_completed(task.task_id, status="failed")
logger.info(
f"Processed failure for task {task.task_id} "
f"(operation {message.operation_id}): {error_msg}"
)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message.
This is a helper function for testing or for services that want to
publish completion messages directly.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
try:
await rabbitmq.connect()
await rabbitmq.publish_message(
routing_key="operation.complete",
message=message.model_dump_json(),
exchange=OPERATION_COMPLETE_EXCHANGE,
)
logger.info(f"Published completion for operation {operation_id}")
finally:
await rabbitmq.disconnect()

View File

@@ -33,9 +33,29 @@ class ChatConfig(BaseSettings):
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=3, description="Maximum number of agent schedules"
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=1000,
description="Maximum number of messages to store per stream",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration

View File

@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(
session_id: str,
tool_call_id: str,
new_content: str,
) -> bool:
"""Update the content of a tool message in chat history.
Used by background tasks to update pending operation messages with final results.
Args:
session_id: The chat session ID.
tool_call_id: The tool call ID to find the message.
new_content: The new content to set.
Returns:
True if a message was updated, False otherwise.
"""
try:
result = await PrismaChatMessage.prisma().update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={
"content": new_content,
},
)
if result == 0:
logger.warning(
f"No message found to update for session {session_id}, "
f"tool_call_id {tool_call_id}"
)
return False
return True
except Exception as e:
logger.error(
f"Failed to update tool message for session {session_id}, "
f"tool_call_id {tool_call_id}: {e}"
)
return False

View File

@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
Used by background tasks to ensure fresh data is loaded on next access.
This is best-effort - Redis failures are logged but don't fail the operation.
"""
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Best-effort: log but don't fail - cache will expire naturally
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)

View File

@@ -5,15 +5,17 @@ from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish
config = ChatConfig()
@@ -81,6 +83,14 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -366,6 +376,243 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_idx: int = Query(default=0, ge=0, description="Last message index received"),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_idx: Last message index received (0 for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting from last_idx.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_idx=last_idx,
)
if subscriber_queue is None:
raise NotFoundError(f"Task {task_id} not found or access denied.")
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
try:
while True:
# Wait for next chunk from the queue
chunk = await subscriber_queue.get()
chunk_count += 1
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"Task stream completed for task {task_id}, "
f"chunk_count={chunk_count}"
)
break
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key
if config.internal_api_key:
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
else:
# If no internal API key is configured, log a warning
logger.warning(
"Operation complete webhook called without API key validation "
"(CHAT_INTERNAL_API_KEY not configured)"
)
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
# Publish result to stream registry
from .response_model import StreamToolOutputAvailable
result_output = request.result if request.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else str(result_output)
),
success=True,
),
)
# Update pending operation in database
from . import service as svc
result_str = (
request.result
if isinstance(request.result, str)
else str(request.result) if request.result else '{"status": "completed"}'
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await svc._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed
await stream_registry.mark_task_completed(task.task_id, status="completed")
else:
# Publish error to stream registry
from .response_model import StreamError
error_msg = request.error or "Operation failed"
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
from . import service as svc
from .tools.models import ErrorResponse
error_response = ErrorResponse(
message=error_msg,
error=request.error,
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed
await stream_registry.mark_task_completed(task.task_id, status="failed")
return {"status": "ok", "task_id": task.task_id}
# ========== Health Check ==========

File diff suppressed because it is too large.

View File

@@ -0,0 +1,470 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It supports:
- Creating tasks with unique IDs for long-running operations
- Publishing stream messages to both Redis Streams and in-memory queues
- Subscribing to tasks with replay of missed messages
- Looking up tasks by operation_id for webhook callbacks
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
@dataclass
class ActiveTask:
"""Represents an active streaming task."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
asyncio_task: asyncio.Task | None = None
# Module-level registry for active tasks
_active_tasks: dict[str, ActiveTask] = {}
# Redis key patterns
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{TASK_META_PREFIX}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{TASK_STREAM_PREFIX}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{TASK_OP_PREFIX}{operation_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in memory and Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store in memory registry
_active_tasks[task_id] = task
# Store metadata in Redis for durability
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.info(
f"Created streaming task {task_id} for operation {operation_id} "
f"in session {session_id}"
)
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> int:
"""Publish a chunk to the task's stream.
Writes to both Redis Stream (for replay) and in-memory queue (for live subscribers).
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The message index in the Redis Stream
"""
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Serialize chunk to JSON
chunk_json = chunk.model_dump_json()
# Add to Redis Stream with auto-generated ID
# The ID format is "timestamp-sequence" which gives us ordering
message_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
# Publish to in-memory queue if task exists
task = _active_tasks.get(task_id)
if task:
try:
task.queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.warning(f"Queue full for task {task_id}, dropping chunk")
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
# Parse the message_id to extract the index
# Redis Stream IDs are "timestamp-sequence", we return the raw ID
return int(message_id.split("-")[1]) if "-" in message_id else 0
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_idx: int = 0,
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_idx: Last message index received (0 for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
# Check in-memory first
task = _active_tasks.get(task_id)
if task:
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task.user_id}"
)
return None
# Create a new queue for this subscriber
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
# Replay from Redis Stream
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Read all messages from stream
# Use "0-0" to get all messages or construct ID from last_idx
start_id = "0-0" if last_idx == 0 else f"0-{last_idx}"
messages = await redis.xread({stream_key: start_id}, block=0, count=1000)
if messages:
# messages format: [[stream_name, [(id, {data: json}), ...]]]
for _stream_name, stream_messages in messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
# Reconstruct the appropriate response type
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# If task is still running, set up live subscription
if task.status == "running":
# Forward messages from task queue to subscriber queue
async def _forward_messages():
try:
while True:
chunk = await task.queue.get()
await subscriber_queue.put(chunk)
if isinstance(chunk, StreamFinish):
break
except asyncio.CancelledError:
pass
asyncio.create_task(_forward_messages())
else:
# Task is done, add finish marker
await subscriber_queue.put(StreamFinish())
return subscriber_queue
# Try to load from Redis if not in memory
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.warning(f"Task {task_id} not found in memory or Redis")
return None
# Validate ownership
task_user_id = meta.get(b"user_id", b"").decode() or None
if user_id and task_user_id and task_user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task_user_id}"
)
return None
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
subscriber_queue = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
start_id = "0-0" if last_idx == 0 else f"0-{last_idx}"
messages = await redis.xread({stream_key: start_id}, block=0, count=1000)
if messages:
for _stream_name, stream_messages in messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Add finish marker since task is not active
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> None:
"""Mark a task as completed and publish final event.
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
"""
task = _active_tasks.get(task_id)
if task:
task.status = status
# Publish finish event to all subscribers
await publish_chunk(task_id, StreamFinish())
# Remove from active tasks after a short delay to allow subscribers to finish
async def _cleanup():
await asyncio.sleep(5)
_active_tasks.pop(task_id, None)
logger.info(f"Cleaned up task {task_id} from memory")
asyncio.create_task(_cleanup())
# Update Redis metadata
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
await redis.hset(meta_key, "status", status) # type: ignore[misc]
logger.info(f"Marked task {task_id} as {status}")
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
# Check in-memory first
for task in _active_tasks.values():
if task.operation_id == operation_id:
return task
# Try Redis lookup
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if task_id:
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
# Check if task is in memory
if task_id_str in _active_tasks:
return _active_tasks[task_id_str]
# Load metadata from Redis
meta_key = _get_task_meta_key(task_id_str)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
# Reconstruct task object (not fully active, but has metadata)
return ActiveTask(
task_id=meta.get(b"task_id", b"").decode(),
session_id=meta.get(b"session_id", b"").decode(),
user_id=meta.get(b"user_id", b"").decode() or None,
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=operation_id,
status=meta.get(b"status", b"running").decode(), # type: ignore
)
return None
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
# Check in-memory first
if task_id in _active_tasks:
return _active_tasks[task_id]
# Try Redis lookup
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
return ActiveTask(
task_id=meta.get(b"task_id", b"").decode(),
session_id=meta.get(b"session_id", b"").decode(),
user_id=meta.get(b"user_id", b"").decode() or None,
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=meta.get(b"operation_id", b"").decode(),
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
)
return None
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
chunk_type = chunk_data.get("type")
try:
if chunk_type == ResponseType.START.value:
return StreamStart(**chunk_data)
elif chunk_type == ResponseType.FINISH.value:
return StreamFinish(**chunk_data)
elif chunk_type == ResponseType.TEXT_START.value:
return StreamTextStart(**chunk_data)
elif chunk_type == ResponseType.TEXT_DELTA.value:
return StreamTextDelta(**chunk_data)
elif chunk_type == ResponseType.TEXT_END.value:
return StreamTextEnd(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_START.value:
return StreamToolInputStart(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_AVAILABLE.value:
return StreamToolInputAvailable(**chunk_data)
elif chunk_type == ResponseType.TOOL_OUTPUT_AVAILABLE.value:
return StreamToolOutputAvailable(**chunk_data)
elif chunk_type == ResponseType.ERROR.value:
return StreamError(**chunk_data)
elif chunk_type == ResponseType.USAGE.value:
return StreamUsage(**chunk_data)
elif chunk_type == ResponseType.HEARTBEAT.value:
return StreamHeartbeat(**chunk_data)
else:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Associate an asyncio.Task with an ActiveTask.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to associate
"""
task = _active_tasks.get(task_id)
if task:
task.asyncio_task = asyncio_task

View File

@@ -0,0 +1,79 @@
# CoPilot Tools - Future Ideas
## Multimodal Image Support for CoPilot
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
**Backend Solution:**
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
```python
# Before sending to LLM, scan for workspace image references
# and inject them as image content parts
# Example message transformation:
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
# TO: {"role": "assistant", "content": [
# {"type": "text", "text": "Generated image: workspace://abc123"},
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# ]}
```
**Where to implement:**
- In the chat stream handler before calling the LLM
- Or in a message preprocessing step
- Need to fetch image from workspace, convert to base64, add as image content
**Considerations:**
- Only do this for image MIME types (image/png, image/jpeg, etc.)
- May want a size limit (don't pass 10MB images)
- Track which images were "shown" to the AI for frontend indicator
- Cost implications - vision API calls are more expensive
**Frontend Solution:**
Show visual indicator on workspace files in chat:
- If AI saw the image: normal display
- If AI didn't see it: overlay icon saying "AI can't see this image"
Requires response metadata indicating which `workspace://` refs were passed to the model.
---
## Output Post-Processing Layer for run_block
**Problem:** Many blocks produce large outputs that:
- Consume massive context (a 100KB image becomes ~133KB of base64 text, all of it counted as tokens)
- Can't fit in conversation
- Break things and cause high LLM costs
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
**Benefits:**
1. **Centralized** - one place to handle all output processing
2. **Future-proof** - new blocks automatically get output processing
3. **Keeps blocks pure** - they don't need to know about context constraints
4. **Handles all large outputs** - not just images
**Processing Rules:**
- Detect base64 data URIs → save to workspace, return `workspace://` reference
- Truncate very long strings (>N chars) with truncation note
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
- Handle nested large outputs in dicts recursively
- Cap total output size
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
**Example:**
```python
def _process_outputs_for_context(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
max_string_length: int = 10000,
max_array_preview: int = 5,
) -> dict[str, list[Any]]:
"""Process block outputs to prevent context bloat."""
processed = {}
for name, values in outputs.items():
processed[name] = [_process_value(v, workspace_manager) for v in values]
return processed
```
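
A possible `_process_value` companion to the sketch above, applying the
processing rules listed earlier. All thresholds and the `WorkspaceManager`
call shown here are assumptions:

```python
# Hypothetical companion to the sketch above; the save_data_uri() method
# and the size thresholds are assumptions, not an existing API.
from typing import Any

MAX_STRING_LENGTH = 10_000
MAX_ARRAY_PREVIEW = 5


def _process_value(value: Any, workspace_manager) -> Any:
    if isinstance(value, str):
        if value.startswith("data:") and ";base64," in value:
            # Save base64 payloads to the workspace, return a reference instead.
            file_id = workspace_manager.save_data_uri(value)  # assumed method
            return f"workspace://{file_id}"
        if len(value) > MAX_STRING_LENGTH:
            return value[:MAX_STRING_LENGTH] + f"... [truncated, {len(value)} chars total]"
        return value
    if isinstance(value, list) and len(value) > MAX_ARRAY_PREVIEW:
        preview = [_process_value(v, workspace_manager) for v in value[:MAX_ARRAY_PREVIEW]]
        return {
            "summary": f"Array with {len(value)} items, first {MAX_ARRAY_PREVIEW} shown",
            "preview": preview,
        }
    if isinstance(value, dict):
        # Handle nested large outputs recursively.
        return {k: _process_value(v, workspace_manager) for k, v in value.items()}
    return value
```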

View File

@@ -18,6 +18,12 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WriteWorkspaceFileTool,
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -37,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),
"delete_workspace_file": DeleteWorkspaceFileTool(),
}
# Export individual tool instances for backwards compatibility
@@ -49,6 +60,11 @@ tools: list[ChatCompletionToolParam] = [
]
def get_tool(tool_name: str) -> BaseTool | None:
"""Get a tool instance by name."""
return TOOL_REGISTRY.get(tool_name)
async def execute_tool(
tool_name: str,
parameters: dict[str, Any],
@@ -57,7 +73,7 @@ async def execute_tool(
tool_call_id: str,
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
tool = TOOL_REGISTRY.get(tool_name)
tool = get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")

View File

@@ -36,6 +36,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -42,6 +42,10 @@ class CreateAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -42,6 +42,10 @@ class EditAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -28,6 +28,16 @@ class ResponseType(str, Enum):
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Base response model
@@ -334,3 +344,43 @@ class BlockOutputResponse(ToolResponseBase):
block_name: str
outputs: dict[str, list[Any]]
success: bool = True
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
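For context, a hypothetical client-side reconnection sketch against the stream endpoint named in the `OperationStartedResponse` docstring; httpx and the SSE framing details are assumptions, not part of this change:
```python
import httpx

async def resume_task_stream(base_url: str, task_id: str, last_idx: int = 0):
    """Reconnect to a long-running task's SSE stream and yield event payloads."""
    url = f"{base_url}/chat/tasks/{task_id}/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, params={"last_idx": last_idx}) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    yield line[len("data: "):]
```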

View File

@@ -1,6 +1,7 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from collections import defaultdict
from typing import Any
@@ -8,6 +9,7 @@ from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
@@ -223,11 +225,48 @@ class RunBlockTool(BaseTool):
)
try:
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": ExecutionContext(),
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
}
for field_name, cred_meta in matched_credentials.items():

View File

@@ -0,0 +1,620 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
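For reference, illustrative tool-call arguments for `write_workspace_file`, matching the parameter schema above; only `filename` and `content_base64` are required, and the values here are made up:
```python
import base64

params = {
    "filename": "report.md",
    "content_base64": base64.b64encode(b"# Q3 Report\n").decode("utf-8"),
    "path": "/documents/report.md",  # optional; defaults to /{filename}
    "overwrite": True,  # optional; default False
}
```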

View File

@@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.settings import Config
@@ -64,11 +64,11 @@ async def list_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise DatabaseError("Invalid pagination input")
raise InvalidInputError("Invalid pagination input")
if search_term and len(search_term.strip()) > 100:
logger.warning(f"Search term too long: {repr(search_term)}")
raise DatabaseError("Search term is too long")
raise InvalidInputError("Search term is too long")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -175,7 +175,7 @@ async def list_favorite_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise DatabaseError("Invalid pagination input")
raise InvalidInputError("Invalid pagination input")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,

View File

@@ -1,4 +1,3 @@
import logging
from typing import Literal, Optional
import autogpt_libs.auth as autogpt_auth_lib
@@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
from prisma.enums import OnboardingStep
import backend.api.features.store.exceptions as store_exceptions
from backend.data.onboarding import complete_onboarding_step
from backend.util.exceptions import DatabaseError, NotFoundError
from .. import db as library_db
from .. import model as library_model
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/agents",
tags=["library", "private"],
@@ -26,10 +21,6 @@ router = APIRouter(
"",
summary="List Library Agents",
response_model=library_model.LibraryAgentResponse,
responses={
200: {"description": "List of library agents"},
500: {"description": "Server error", "content": {"application/json": {}}},
},
)
async def list_library_agents(
user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -53,43 +44,19 @@ async def list_library_agents(
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
Args:
user_id: ID of the authenticated user.
search_term: Optional search term to filter agents by name/description.
filter_by: List of filters to apply (favorites, created by user).
sort_by: List of sorting criteria (created date, updated date).
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
"""
try:
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
@router.get(
"/favorites",
summary="List Favorite Library Agents",
responses={
500: {"description": "Server error", "content": {"application/json": {}}},
},
)
async def list_favorite_library_agents(
user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -106,30 +73,12 @@ async def list_favorite_library_agents(
) -> library_model.LibraryAgentResponse:
"""
Get all favorite agents in the user's library.
Args:
user_id: ID of the authenticated user.
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing favorite agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
"""
try:
return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
@router.get("/{library_agent_id}", summary="Get Library Agent")
@@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id(
summary="Get Agent By Store ID",
tags=["store", "library"],
response_model=library_model.LibraryAgent | None,
responses={
200: {"description": "Library agent found"},
404: {"description": "Agent not found"},
},
)
async def get_library_agent_by_store_listing_version_id(
store_listing_version_id: str,
@@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id(
"""
Get Library Agent from Store Listing Version ID.
"""
try:
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.error(f"Could not fetch library agent from store version ID: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
@router.post(
"",
summary="Add Marketplace Agent",
status_code=status.HTTP_201_CREATED,
responses={
201: {"description": "Agent added successfully"},
404: {"description": "Store listing version not found"},
500: {"description": "Server error"},
},
)
async def add_marketplace_agent_to_library(
store_listing_version_id: str = Body(embed=True),
@@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library(
) -> library_model.LibraryAgent:
"""
Add an agent from the marketplace to the user's library.
Args:
store_listing_version_id: ID of the store listing version to add.
user_id: ID of the authenticated user.
Returns:
library_model.LibraryAgent: Agent added to the library
Raises:
HTTPException(404): If the listing version is not found.
HTTPException(500): If a server/database error occurs.
"""
try:
agent = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
)
return agent
except store_exceptions.AgentNotFoundError as e:
logger.warning(
f"Could not find store listing version {store_listing_version_id} "
"to add to library"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except DatabaseError as e:
logger.error(f"Database error while adding agent to library: {e}", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Inspect DB logs for details."},
) from e
except Exception as e:
logger.error(f"Unexpected error while adding agent to library: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": str(e),
"hint": "Check server logs for more information.",
},
) from e
agent = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
return agent
@router.patch(
"/{library_agent_id}",
summary="Update Library Agent",
responses={
200: {"description": "Agent updated successfully"},
500: {"description": "Server error"},
},
)
async def update_library_agent(
library_agent_id: str,
@@ -271,52 +159,21 @@ async def update_library_agent(
) -> library_model.LibraryAgent:
"""
Update the library agent with the given fields.
Args:
library_agent_id: ID of the library agent to update.
payload: Fields to update (auto_update_version, is_favorite, etc.).
user_id: ID of the authenticated user.
Raises:
HTTPException(500): If a server/database error occurs.
"""
try:
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
graph_version=payload.graph_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except DatabaseError as e:
logger.error(f"Database error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Verify DB connection."},
) from e
except Exception as e:
logger.error(f"Unexpected error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check server logs."},
) from e
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
graph_version=payload.graph_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
)
@router.delete(
"/{library_agent_id}",
summary="Delete Library Agent",
responses={
204: {"description": "Agent deleted successfully"},
404: {"description": "Agent not found"},
500: {"description": "Server error"},
},
)
async def delete_library_agent(
library_agent_id: str,
@@ -324,28 +181,11 @@ async def delete_library_agent(
) -> Response:
"""
Soft-delete the specified library agent.
Args:
library_agent_id: ID of the library agent to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
Raises:
HTTPException(404): If the agent does not exist.
HTTPException(500): If a server/database error occurs.
"""
try:
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")

View File

@@ -118,21 +118,6 @@ async def test_get_library_agents_success(
)
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents?search_term=test")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id=test_user_id,
search_term="test",
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
)
@pytest.mark.asyncio
async def test_get_favorite_library_agents_success(
mocker: pytest_mock.MockFixture,
@@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success(
)
def test_get_favorite_library_agents_error(
mocker: pytest_mock.MockFixture, test_user_id: str
):
mock_db_call = mocker.patch(
"backend.api.features.library.db.list_favorite_library_agents"
)
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents/favorites")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id=test_user_id,
page=1,
page_size=15,
)
def test_add_agent_to_library_success(
mocker: pytest_mock.MockFixture, test_user_id: str
):
@@ -258,19 +226,3 @@ def test_add_agent_to_library_success(
store_listing_version_id="test-version-id", user_id=test_user_id
)
mock_complete_onboarding.assert_awaited_once()
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
mock_db_call = mocker.patch(
"backend.api.features.library.db.add_store_agent_to_library"
)
mock_db_call.side_effect = Exception("Test error")
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 500
assert "detail" in response.json() # Verify error response structure
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id=test_user_id
)

View File

@@ -454,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
total_processed = 0
total_success = 0
total_failed = 0
all_errors: dict[str, int] = {} # Aggregate errors across all content types
# Process content types in explicit order
processing_order = [
@@ -499,23 +500,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
success = sum(1 for result in results if result is True)
failed = len(results) - success
# Aggregate unique errors to avoid Sentry spam
# Aggregate errors across all content types
if failed > 0:
# Group errors by type and message
error_summary: dict[str, int] = {}
for result in results:
if isinstance(result, Exception):
error_key = f"{type(result).__name__}: {str(result)}"
error_summary[error_key] = error_summary.get(error_key, 0) + 1
# Log aggregated error summary
error_details = ", ".join(
f"{error} ({count}x)" for error, count in error_summary.items()
)
logger.error(
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
f"Errors: {error_details}"
)
all_errors[error_key] = all_errors.get(error_key, 0) + 1
results_by_type[content_type.value] = {
"processed": len(missing_items),
@@ -542,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
"error": str(e),
}
# Log aggregated errors once at the end
if all_errors:
error_details = ", ".join(
f"{error} ({count}x)" for error, count in all_errors.items()
)
logger.error(f"Embedding backfill errors: {error_details}")
return {
"by_type": results_by_type,
"totals": {

View File

@@ -261,14 +261,36 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding status check."""
is_onboarding_enabled: bool
is_chat_enabled: bool
@v1_router.get(
"/onboarding/enabled",
summary="Is onboarding enabled",
tags=["onboarding", "public"],
dependencies=[Security(requires_user)],
response_model=OnboardingStatusResponse,
)
async def is_onboarding_enabled() -> bool:
return await onboarding_enabled()
async def is_onboarding_enabled(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
return OnboardingStatusResponse(
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
)
@v1_router.post(

View File

@@ -0,0 +1 @@
# Workspace API feature module

View File

@@ -0,0 +1,122 @@
"""
Workspace API routes for managing user file storage.
"""
import logging
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
Removes/replaces characters that could break the header or inject new headers.
Uses RFC5987 encoding for non-ASCII characters.
"""
# Remove CR, LF, and null bytes (header injection prevention)
sanitized = re.sub(r"[\r\n\x00]", "", filename)
# Escape quotes
sanitized = sanitized.replace('"', '\\"')
# For non-ASCII, use RFC5987 filename* parameter
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
dependencies=[fastapi.Security(requires_user)],
)
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
Handles both local storage (direct streaming) and GCS (signed URL redirect
with fallback to streaming).
"""
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> Response:
"""
Download a file by its ID.
Returns the file content directly or redirects to a signed URL for GCS.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)

View File

@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
@@ -52,6 +53,7 @@ from backend.util.exceptions import (
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
from backend.util.workspace_storage import shutdown_workspace_storage
from .external.fastapi_app import external_api
from .features.analytics import router as analytics_router
@@ -124,6 +126,11 @@ async def lifespan_context(app: fastapi.FastAPI):
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")
try:
await shutdown_workspace_storage()
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
await backend.data.db.disconnect()
@@ -315,6 +322,11 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
workspace_routes.router,
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("image_url", "https://replicate.delivery/generated-image.jpg"),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: MediaFileType(
"https://replicate.delivery/generated-image.jpg"
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
),
},
test_credentials=TEST_CREDENTIALS,
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
for img in input_data.images
)
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
yield "image_url", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
except Exception as e:
yield "error", str(e)

View File

@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -13,6 +14,8 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class ImageSize(str, Enum):
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
"https://replicate.delivery/generated-image.webp",
# Test output is a data URI since we now store images
lambda x: x.startswith("data:image/"),
),
],
test_mock={
"_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
},
)
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
style_text = style_map.get(style, "")
return f"{style_text} of" if style_text else ""
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
try:
url = await self.generate_image(input_data, credentials)
if url:
yield "image_url", url
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "error", "Image generation returned an empty result."
except Exception as e:

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -21,7 +22,9 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
# Use data URI to avoid HTTP requests during tests
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create a new Webhook.site URL
webhook_token, webhook_url = await self.create_webhook()
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIAdMakerVideoCreatorBlock(Block):
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=("video_url", "https://example.com/ad.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/ad.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIScreenshotToVideoAdBlock(Block):
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/screenshot.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from pydantic import SecretStr
from backend.data.execution import ExecutionContext
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -17,6 +18,8 @@ from backend.sdk import (
Requests,
SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._config import bannerbear
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
},
test_output=[
("success", True),
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("uid", "test-uid-123"),
("status", "completed"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"_make_api_request": lambda *args, **kwargs: {
"uid": "test-uid-123",
"status": "completed",
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
"image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
}
},
test_credentials=TEST_CREDENTIALS,
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
raise Exception(error_msg)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Build the modifications array
modifications = []
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
# Synchronous request - image should be ready
yield "success", True
yield "image_url", data.get("image_url", "")
# Store the generated image to workspace for persistence
image_url = data.get("image_url", "")
if image_url:
stored_url = await store_media_file(
file=MediaFileType(image_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "image_url", ""
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -9,6 +9,7 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
class FileStoreBlock(Block):
class Input(BlockSchemaInput):
file_in: MediaFileType = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
class Output(BlockSchemaOutput):
file_out: MediaFileType = SchemaField(
description="The relative path to the stored file in the temporary directory."
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
)
def __init__(self):
super().__init__(
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
description="Stores the input file in the temporary directory.",
description=(
"Downloads and stores a file from a URL, data URI, or local path. "
"Use this to fetch images, documents, or other files for processing. "
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
"In graphs: outputs a data URI to pass to other blocks."
),
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
input_schema=FileStoreBlock.Input,
output_schema=FileStoreBlock.Output,
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
file: MediaFileType,
filename: str,
message_content: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> dict:
intents = discord.Intents.default()
intents.guilds = True
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
# Local file path - read from stored media file
# This would be a path from a previous block's output
stored_file = await store_media_file(
graph_exec_id=graph_exec_id,
file=file,
user_id=user_id,
return_content=True, # Get as data URI
execution_context=execution_context,
return_format="for_external_api", # Get content to send to Discord
)
# Now process as data URI
header, encoded = stored_file.split(",", 1)
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
file=input_data.file,
filename=input_data.filename,
message_content=input_data.message_content,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
)
yield "status", result.get("status", "Unknown error")

View File

@@ -17,8 +17,11 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
logger = logging.getLogger(__name__)
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
test_output=[
# Output will be a workspace ref or data URI depending on context
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
# Use data URI to avoid HTTP requests during tests
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
},
)
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
raise RuntimeError(f"API request failed: {str(e)}")
async def run(
self, input_data: Input, *, credentials: FalCredentials, **kwargs
self,
input_data: Input,
*,
credentials: FalCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
video_url = await self.generate_video(input_data, credentials)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
except Exception as e:
error_message = str(e)
yield "error", error_message

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("output_image", "https://replicate.com/output/edited-image.png"),
# Output will be a workspace ref or data URI depending on context
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
)
yield "output_image", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output_image", stored_url
async def run_model(
self,

View File

@@ -21,6 +21,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -95,8 +96,7 @@ def _make_mime_text(
async def create_mime_message(
input_data,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
@@ -117,12 +117,12 @@ async def create_mime_message(
if input_data.attachments:
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._send_email(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", result
async def _send_email(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject or not input_data.body:
raise ValueError(
"At least one recipient, subject, and body are required for sending an email"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._create_draft(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
)
async def _create_draft(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject:
raise ValueError(
"At least one recipient and subject are required for creating a draft"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
async def _build_reply_message(
service, input_data, graph_exec_id: str, user_id: str
service, input_data, execution_context: ExecutionContext
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
message = await self._reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", message["id"]
yield "threadId", message.get("threadId", input_data.threadId)
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
yield "email", email
async def _reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Send the message
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
yield "status", "draft_created"
async def _create_draft_reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Create draft with proper thread association
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._forward_message(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", result["id"]
yield "threadId", result.get("threadId", "")
yield "status", "forwarded"
async def _forward_message(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to:
raise ValueError("At least one recipient is required for forwarding")
@@ -1727,12 +1717,12 @@ To: {original_to}
# Add any additional attachments
for attach in input_data.additionalAttachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
@staticmethod
async def _prepare_files(
graph_exec_id: str,
execution_context: ExecutionContext,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
(files_name, (filename, BytesIO, mime_type))
"""
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
graph_exec_id = execution_context.graph_exec_id
if graph_exec_id is None:
raise ValueError("graph_exec_id is required for file operations")
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
graph_exec_id, media, user_id, return_content=False
file=media,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
graph_exec_id, input_data.files_name, input_data.files, user_id
execution_context, input_data.files_name, input_data.files
)
# Enforce body format rules
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
execution_context: ExecutionContext,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
base_input, execution_context=execution_context, **kwargs
):
yield output_name, output_data

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -1,6 +1,6 @@
import os
import tempfile
from typing import Literal, Optional
from typing import Optional
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
# 2) Load the clip
if input_data.is_video:
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
default=None,
ge=1,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="How to return the output video. Either a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: str = SchemaField(
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return as data URI
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="Return the final output as a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return either path or data URI
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out
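The media blocks above all follow the same round trip: materialize the input locally, process it, then publish the result with `for_block_output`. A hedged sketch of that pattern (the processing step is elided):

```python
from backend.util.file import get_exec_file_path, store_media_file

async def process_media(input_data, execution_context):
    # 1) Materialize the input in the exec_file directory
    local_path = await store_media_file(
        file=input_data.media_in,
        execution_context=execution_context,
        return_format="for_local_processing",
    )
    assert execution_context.graph_exec_id  # validated by store_media_file
    abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)

    # 2) ... run MoviePy/ffmpeg against abs_path, writing a new output file ...
    output_filename = local_path  # placeholder; real blocks write a new file

    # 3) Publish: workspace:// in CoPilot, data URI in plain graph runs
    return await store_media_file(
        file=output_filename,
        execution_context=execution_context,
        return_format="for_block_output",
    )
```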

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
@staticmethod
async def take_screenshot(
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
url: str,
viewport_width: int,
viewport_height: int,
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
return {
"image": await store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_block_output",
)
}
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -7,6 +7,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
assert execution_context.graph_exec_id # Validated by store_media_file
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -10,6 +10,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -17,7 +18,9 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
test_output=[
(
"video_url",
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
lambda x: x.startswith(("workspace://", "data:")),
),
],
test_mock={
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
"status": "created",
},
# Use data URI to avoid HTTP requests during tests
"get_clip_status": lambda *args, **kwargs: {
"status": "done",
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
"result_url": "data:video/mp4;base64,AAAA",
},
},
test_credentials=TEST_CREDENTIALS,
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create the clip
payload = {
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
for _ in range(input_data.max_polling_attempts):
status_response = await self.get_clip_status(credentials.api_key, clip_id)
if status_response["status"] == "done":
yield "video_url", status_response["result_url"]
# Store the generated video to the user's workspace for persistence
video_url = status_response["result_url"]
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
return
elif status_response["status"] == "error":
raise RuntimeError(

View File

@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
)
@patch("backend.util.file.Path")
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
)

View File

@@ -11,10 +11,22 @@ from backend.blocks.http import (
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.execution import ExecutionContext
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
def make_test_context(
graph_exec_id: str = "test-exec-id",
user_id: str = "test-user-id",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util import json, text
from backend.util.file import get_exec_file_path, store_media_file
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
# Store the media file properly (handles URLs, data URIs, etc.)
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
# Get full file path (graph_exec_id validated by store_media_file above)
if not execution_context.graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
model_config = {"extra": "ignore"}
# Execution identity
user_id: Optional[str] = None
graph_id: Optional[str] = None
graph_exec_id: Optional[str] = None
graph_version: Optional[int] = None
node_id: Optional[str] = None
node_exec_id: Optional[str] = None
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
# User settings
user_timezone: str = "UTC"
# Execution hierarchy
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# Workspace
workspace_id: Optional[str] = None
session_id: Optional[str] = None
# -------------------------- Models -------------------------- #
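Since every field is optional, callers populate only what they have; the workspace fields are what flip `for_block_output` into workspace mode. A sketch with made-up IDs:

```python
from backend.data.execution import ExecutionContext

# Plain graph run: no workspace, so for_block_output falls back to data URIs
graph_ctx = ExecutionContext(user_id="user-1", graph_exec_id="exec-1")

# CoPilot session: workspace_id/session_id route block outputs to workspace://
copilot_ctx = graph_ctx.model_copy(
    update={"workspace_id": "ws-1", "session_id": "sess-1"}
)
```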

View File

@@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.AGENT_INPUT,
OnboardingStep.CONGRATS,
OnboardingStep.VISIT_COPILOT,
OnboardingStep.MARKETPLACE_VISIT,
OnboardingStep.BUILDER_OPEN,
]
@@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
reward = 0
match step:
# Welcome bonus for visiting copilot ($5 = 500 credits)
case OnboardingStep.VISIT_COPILOT:
reward = 500
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
# This is seen as a reward for the GET_RESULTS step in the wallet

View File

@@ -0,0 +1,276 @@
"""
Database CRUD operations for User Workspace.
This module provides functions for managing user workspaces and workspace files.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.models import UserWorkspace, UserWorkspaceFile
from prisma.types import UserWorkspaceFileWhereInput
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
"""
Get user's workspace, creating one if it doesn't exist.
Uses upsert to handle race conditions when multiple concurrent requests
attempt to create a workspace for the same user.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance
"""
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No updates needed if exists
},
)
return workspace
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
"""
Get user's workspace if it exists.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
async def create_workspace_file(
workspace_id: str,
file_id: str,
name: str,
path: str,
storage_path: str,
mime_type: str,
size_bytes: int,
checksum: Optional[str] = None,
metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
"""
Create a new workspace file record.
Args:
workspace_id: The workspace ID
file_id: The file ID (same as used in storage path for consistency)
name: User-visible filename
path: Virtual path (e.g., "/documents/report.pdf")
storage_path: Actual storage path (GCS or local)
mime_type: MIME type of the file
size_bytes: File size in bytes
checksum: Optional SHA256 checksum
metadata: Optional additional metadata
Returns:
Created UserWorkspaceFile instance
"""
# Normalize path to start with /
if not path.startswith("/"):
path = f"/{path}"
file = await UserWorkspaceFile.prisma().create(
data={
"id": file_id,
"workspaceId": workspace_id,
"name": name,
"path": path,
"storagePath": storage_path,
"mimeType": mime_type,
"sizeBytes": size_bytes,
"checksum": checksum,
"metadata": SafeJson(metadata or {}),
}
)
logger.info(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
)
return file
async def get_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by ID.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
UserWorkspaceFile instance or None
"""
where_clause: dict = {"id": file_id, "isDeleted": False}
if workspace_id:
where_clause["workspaceId"] = workspace_id
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
async def get_workspace_file_by_path(
workspace_id: str,
path: str,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by its virtual path.
Args:
workspace_id: The workspace ID
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
# Normalize path
if not path.startswith("/"):
path = f"/{path}"
return await UserWorkspaceFile.prisma().find_first(
where={
"workspaceId": workspace_id,
"path": path,
"isDeleted": False,
}
)
async def list_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
limit: Optional[int] = None,
offset: int = 0,
) -> list[UserWorkspaceFile]:
"""
List files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/documents/")
include_deleted: Whether to include soft-deleted files
limit: Maximum number of files to return
offset: Number of files to skip
Returns:
List of UserWorkspaceFile instances
"""
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().find_many(
where=where_clause,
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def count_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
) -> int:
"""
Count files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
include_deleted: Whether to include soft-deleted files
Returns:
Number of files
"""
where_clause: dict = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().count(where=where_clause)
async def soft_delete_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Soft-delete a workspace file.
The path is modified to include a deletion timestamp to free up the original
path for new files while preserving the record for potential recovery.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
Updated UserWorkspaceFile instance or None if not found
"""
# First verify the file exists and belongs to workspace
file = await get_workspace_file(file_id, workspace_id)
if file is None:
return None
deleted_at = datetime.now(timezone.utc)
# Modify path to free up the unique constraint for new files at original path
# Format: {original_path}__deleted__{timestamp}
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
updated = await UserWorkspaceFile.prisma().update(
where={"id": file_id},
data={
"isDeleted": True,
"deletedAt": deleted_at,
"path": deleted_path,
},
)
logger.info(f"Soft-deleted workspace file {file_id}")
return updated
async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.sizeBytes for file in files)
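A hedged usage sketch of the helpers above; all IDs and paths are invented:

```python
async def demo(user_id: str) -> None:
    ws = await get_or_create_workspace(user_id)  # upsert: safe under races

    f = await create_workspace_file(
        workspace_id=ws.id,
        file_id="file-123",
        name="report.pdf",
        path="/sessions/sess-1/report.pdf",  # session-scoped virtual path
        storage_path="gcs://bucket/workspaces/ws-1/file-123",
        mime_type="application/pdf",
        size_bytes=2048,
    )

    # Only this session's files
    session_files = await list_workspace_files(ws.id, path_prefix="/sessions/sess-1/")
    assert f.id in [x.id for x in session_files]

    # Soft delete renames the path to "...__deleted__<ts>", freeing it for reuse
    await soft_delete_workspace_file(f.id, workspace_id=ws.id)
```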

View File

@@ -236,7 +236,14 @@ async def execute_node(
input_size = len(input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
# Create node-specific execution context to avoid race conditions
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
execution_context = execution_context.model_copy(
update={"node_id": node_id, "node_exec_id": node_exec_id}
)
# Inject extra execution arguments for the blocks via kwargs
# Keep individual kwargs for backwards compatibility with existing blocks
extra_exec_kwargs: dict = {
"graph_id": graph_id,
"graph_version": graph_version,

View File

@@ -892,11 +892,19 @@ async def add_graph_execution(
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec.id,
graph_version=graph_exec.graph_version,
# Safety settings
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
# User settings
user_timezone=(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
),
# Execution hierarchy
root_execution_id=graph_exec.id,
)

View File

@@ -348,6 +348,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
mock_graph_exec.graph_version = graph_version
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
# Mock the queue and event bus
@@ -434,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Create a second mock execution for the sanity check
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec_2.id = "execution-id-456"
mock_graph_exec_2.node_executions = []
mock_graph_exec_2.status = ExecutionStatus.QUEUED
mock_graph_exec_2.graph_version = graph_version
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
# Reset mocks and set up for second call
@@ -614,6 +618,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
mock_graph_exec.graph_version = graph_version
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}

View File

@@ -13,6 +13,7 @@ import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -251,7 +252,7 @@ class CloudStorageHandler:
f"in_task: {current_task is not None}"
)
# Parse bucket and blob name from path
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path: {path}")
@@ -261,50 +262,19 @@ class CloudStorageHandler:
# Authorization check
self._validate_file_access(blob_name, user_id, graph_exec_id)
# Use a fresh client for each download to avoid session issues
# This is less efficient but more reliable with the executor's event loop
logger.info("[CloudStorage] Creating fresh GCS client for download")
# Create a new session specifically for this download
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10, force_close=True)
logger.info(
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
)
async_client = None
try:
# Create a new GCS client with the fresh session
async_client = async_gcs_storage.Storage(session=session)
logger.info(
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
)
# Download content using the fresh client
content = await async_client.download(bucket_name, blob_name)
content = await download_with_fresh_session(bucket_name, blob_name)
logger.info(
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
)
# Clean up
await async_client.close()
await session.close()
return content
except FileNotFoundError:
raise
except Exception as e:
# Always try to clean up
if async_client is not None:
try:
await async_client.close()
except Exception as cleanup_error:
logger.warning(
f"[CloudStorage] Error closing GCS client: {cleanup_error}"
)
try:
await session.close()
except Exception as cleanup_error:
logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
# Log the specific error for debugging
logger.error(
f"[CloudStorage] GCS download failed - error: {str(e)}, "
@@ -319,10 +289,6 @@ class CloudStorageHandler:
f"current_task: {current_task}, "
f"bucket: {bucket_name}, blob: redacted for privacy"
)
# Convert gcloud-aio exceptions to standard ones
if "404" in str(e) or "Not Found" in str(e):
raise FileNotFoundError(f"File not found: gcs://{path}")
raise
def _validate_file_access(
@@ -445,8 +411,7 @@ class CloudStorageHandler:
graph_exec_id: str | None = None,
) -> str:
"""Generate signed URL for GCS with authorization."""
# Parse bucket and blob name from path
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
parts = path.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path: {path}")
@@ -456,21 +421,11 @@ class CloudStorageHandler:
# Authorization check
self._validate_file_access(blob_name, user_id, graph_exec_id)
# Use sync client for signed URLs since gcloud-aio doesn't support them
sync_client = self._get_sync_gcs_client()
bucket = sync_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
# Generate signed URL asynchronously using sync client
url = await asyncio.to_thread(
blob.generate_signed_url,
version="v4",
expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
method="GET",
return await generate_signed_url(
sync_client, bucket_name, blob_name, expiration_hours * 3600
)
return url
async def delete_expired_files(self, provider: str = "gcs") -> int:
"""
Delete files that have passed their expiration time.

View File

@@ -135,6 +135,12 @@ class GraphValidationError(ValueError):
)
class InvalidInputError(ValueError):
"""Raised when user input validation fails (e.g., search term too long)"""
pass
class DatabaseError(Exception):
"""Raised when there is an error interacting with the database"""

View File

@@ -5,13 +5,26 @@ import shutil
import tempfile
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from urllib.parse import urlparse
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.request import Requests
from backend.util.settings import Config
from backend.util.type import MediaFileType
from backend.util.virus_scanner import scan_content_safe
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
# Return format options for store_media_file
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
MediaReturnFormat = Literal[
"for_local_processing", "for_external_api", "for_block_output"
]
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
# Maximum filename length (conservative limit for most filesystems)
@@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
async def store_media_file(
graph_exec_id: str,
file: MediaFileType,
user_id: str,
return_content: bool = False,
execution_context: "ExecutionContext",
*,
return_format: MediaReturnFormat,
) -> MediaFileType:
"""
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
placing or verifying it under:
Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
relative to {tempdir}/exec_file/{exec_id}), placing or verifying it under:
{tempdir}/exec_file/{exec_id}/...
If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
Otherwise, returns the file media path relative to the exec_id folder.
For each MediaFileType input:
- Data URI: decode and store locally
- URL: download and store locally
- workspace:// reference: read from workspace, store locally
- Local path: verify it exists in exec_file directory
For each MediaFileType type:
- Data URI:
-> decode and store in a new random file in that folder
- URL:
-> download and store in that folder
- Local path:
-> interpret as relative to that folder; verify it exists
(no copying, as it's presumably already there).
We realpath-check so no symlink or '..' can escape the folder.
Return format options:
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
:param graph_exec_id: The unique ID of the graph execution.
:param file: Data URI, URL, or local (relative) path.
:param return_content: If True, return a data URI of the file content.
If False, return the *relative* path inside the exec_id folder.
:return: The requested result: data URI or relative path of the media.
:param file: Data URI, URL, workspace://, or local (relative) path.
:param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
:param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
:return: The requested result based on return_format.
"""
# Extract values from execution_context
graph_exec_id = execution_context.graph_exec_id
user_id = execution_context.user_id
if not graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
if not user_id:
raise ValueError("execution_context.user_id is required")
# Create workspace_manager if we have workspace_id (with session scoping)
# Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
from backend.util.workspace import WorkspaceManager
workspace_manager: WorkspaceManager | None = None
if execution_context.workspace_id:
workspace_manager = WorkspaceManager(
user_id, execution_context.workspace_id, execution_context.session_id
)
# Build base path
base_path = Path(get_exec_file_path(graph_exec_id, ""))
base_path.mkdir(parents=True, exist_ok=True)
# Security fix: Add disk space limits to prevent DoS
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file
MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
# Check total disk usage in base_path
@@ -142,9 +169,57 @@ async def store_media_file(
"""
return str(absolute_path.relative_to(base))
# Check if this is a cloud storage path
# Get cloud storage handler for checking cloud paths
cloud_storage = await get_cloud_storage_handler()
if cloud_storage.is_cloud_path(file):
# Track if the input came from workspace (don't re-save it)
is_from_workspace = file.startswith("workspace://")
# Check if this is a workspace file reference
if is_from_workspace:
if workspace_manager is None:
raise ValueError(
"Workspace file reference requires workspace context. "
"This file type is only available in CoPilot sessions."
)
# Parse workspace reference
# workspace://abc123 - by file ID
# workspace:///path/to/file.txt - by virtual path
file_ref = file[12:] # Remove "workspace://"
if file_ref.startswith("/"):
# Path reference
workspace_content = await workspace_manager.read_file(file_ref)
file_info = await workspace_manager.get_file_info_by_path(file_ref)
filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin"
)
else:
# ID reference
workspace_content = await workspace_manager.read_file_by_id(file_ref)
file_info = await workspace_manager.get_file_info(file_ref)
filename = sanitize_filename(
file_info.name if file_info else f"{uuid.uuid4()}.bin"
)
try:
target_path = _ensure_inside_base(base_path / filename, base_path)
except OSError as e:
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Check file size limit
if len(workspace_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Virus scan the workspace content before writing locally
await scan_content_safe(workspace_content, filename=filename)
target_path.write_bytes(workspace_content)
# Check if this is a cloud storage path
elif cloud_storage.is_cloud_path(file):
# Download from cloud storage and store locally
cloud_content = await cloud_storage.retrieve_file(
file, user_id=user_id, graph_exec_id=graph_exec_id
@@ -159,9 +234,9 @@ async def store_media_file(
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Check file size limit
if len(cloud_content) > MAX_FILE_SIZE:
if len(cloud_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Virus scan the cloud content before writing locally
@@ -189,9 +264,9 @@ async def store_media_file(
content = base64.b64decode(b64_content)
# Check file size limit
if len(content) > MAX_FILE_SIZE:
if len(content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Virus scan the base64 content before writing
@@ -199,23 +274,31 @@ async def store_media_file(
target_path.write_bytes(content)
elif file.startswith(("http://", "https://")):
# URL
# URL - download first to get Content-Type header
resp = await Requests().get(file)
# Check file size limit
if len(resp.content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
# Extract filename from URL path
parsed_url = urlparse(file)
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
# If filename lacks extension, add one from Content-Type header
if "." not in filename:
content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
if content_type:
ext = _extension_from_mime(content_type)
filename = f"{filename}{ext}"
try:
target_path = _ensure_inside_base(base_path / filename, base_path)
except OSError as e:
raise ValueError(f"Invalid file path '{filename}': {e}") from e
# Download and save
resp = await Requests().get(file)
# Check file size limit
if len(resp.content) > MAX_FILE_SIZE:
raise ValueError(
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
)
# Virus scan the downloaded content before writing
await scan_content_safe(resp.content, filename=filename)
target_path.write_bytes(resp.content)
@@ -230,12 +313,44 @@ async def store_media_file(
if not target_path.is_file():
raise ValueError(f"Local file does not exist: {target_path}")
# Return result
if return_content:
return MediaFileType(_file_to_data_uri(target_path))
else:
# Return based on requested format
if return_format == "for_local_processing":
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
# Returns: relative path in exec_file directory (e.g., "image.png")
return MediaFileType(_strip_base_prefix(target_path, base_path))
elif return_format == "for_external_api":
# Use when sending content to external APIs that need base64
# Returns: data URI (e.g., "data:image/png;base64,iVBORw0...")
return MediaFileType(_file_to_data_uri(target_path))
elif return_format == "for_block_output":
# Use when returning output from a block to user/next block
# Returns: workspace:// ref (CoPilot) or data URI (graph execution)
if workspace_manager is None:
# No workspace available (graph execution without CoPilot)
# Fallback to data URI so the content can still be used/displayed
return MediaFileType(_file_to_data_uri(target_path))
# Don't re-save if input was already from workspace
if is_from_workspace:
# Return original workspace reference
return MediaFileType(file)
# Save new content to workspace
content = target_path.read_bytes()
filename = target_path.name
file_record = await workspace_manager.write_file(
content=content,
filename=filename,
overwrite=True,
)
return MediaFileType(f"workspace://{file_record.id}")
else:
raise ValueError(f"Invalid return_format: {return_format}")
def get_dir_size(path: Path) -> int:
"""Get total size of directory."""

View File

@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
def make_test_context(
graph_exec_id: str = "test-exec-123",
user_id: str = "test-user-123",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestFileCloudIntegration:
"""Test cases for cloud storage integration in file utilities."""
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
mock_path_class.side_effect = path_constructor
result = await store_media_file(
graph_exec_id,
MediaFileType(cloud_path),
"test-user-123",
return_content=False,
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify cloud storage operations
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
mock_path_obj.name = "image.png"
with patch("backend.util.file.Path", return_value=mock_path_obj):
result = await store_media_file(
graph_exec_id,
MediaFileType(cloud_path),
"test-user-123",
return_content=True,
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_external_api",
)
# Verify result is a data URI
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
await store_media_file(
graph_exec_id,
MediaFileType(data_uri),
"test-user-123",
return_content=False,
file=MediaFileType(data_uri),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify cloud handler was checked but not used for retrieval
@@ -234,5 +243,7 @@ class TestFileCloudIntegration:
FileNotFoundError, match="File not found in cloud storage"
):
await store_media_file(
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
file=MediaFileType(cloud_path),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)

View File

@@ -0,0 +1,108 @@
"""
Shared GCS utilities for workspace and cloud storage backends.
This module provides common functionality for working with Google Cloud Storage,
including path parsing, client management, and signed URL generation.
"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
logger = logging.getLogger(__name__)
def parse_gcs_path(path: str) -> tuple[str, str]:
"""
Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).
Args:
path: GCS path string (e.g., "gcs://my-bucket/path/to/file")
Returns:
Tuple of (bucket_name, blob_name)
Raises:
ValueError: If the path format is invalid
"""
if not path.startswith("gcs://"):
raise ValueError(f"Invalid GCS path: {path}")
path_without_prefix = path[6:] # Remove "gcs://"
parts = path_without_prefix.split("/", 1)
if len(parts) != 2:
raise ValueError(f"Invalid GCS path format: {path}")
return parts[0], parts[1]
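A quick sanity check of the parsing contract (bucket and blob values illustrative):

bucket, blob = parse_gcs_path("gcs://my-bucket/workspaces/ws-1/report.pdf")
assert (bucket, blob) == ("my-bucket", "workspaces/ws-1/report.pdf")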
async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
"""
Download file content using a fresh session.
This approach avoids event loop issues that can occur when reusing
sessions across different async contexts (e.g., in executors).
Args:
bucket: GCS bucket name
blob: Blob path within the bucket
Returns:
File content as bytes
Raises:
FileNotFoundError: If the file doesn't exist
"""
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10, force_close=True)
)
client: async_gcs_storage.Storage | None = None
try:
client = async_gcs_storage.Storage(session=session)
content = await client.download(bucket, blob)
return content
except Exception as e:
if "404" in str(e) or "Not Found" in str(e):
raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
raise
finally:
if client:
try:
await client.close()
except Exception:
pass # Best-effort cleanup
await session.close()
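The trade-off is one connection setup per call in exchange for loop safety; usage is a single await (bucket and blob names illustrative):

async def demo_download() -> bytes:
    # Raises FileNotFoundError if the blob does not exist.
    return await download_with_fresh_session("my-bucket", "workspaces/ws-1/f-1/report.pdf")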
async def generate_signed_url(
sync_client: gcs_storage.Client,
bucket_name: str,
blob_name: str,
expires_in: int,
) -> str:
"""
Generate a signed URL for temporary access to a GCS file.
Uses asyncio.to_thread() to run the sync operation without blocking.
Args:
sync_client: Sync GCS client with service account credentials
bucket_name: GCS bucket name
blob_name: Blob path within the bucket
expires_in: URL expiration time in seconds
Returns:
Signed URL string
"""
bucket = sync_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
return await asyncio.to_thread(
blob.generate_signed_url,
version="v4",
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
method="GET",
)
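A sketch of a call site, assuming service-account credentials with a private key are available (names illustrative):

async def demo_signed_url() -> str:
    client = gcs_storage.Client()  # must be backed by a service account
    return await generate_signed_url(
        client, "my-bucket", "workspaces/ws-1/f-1/report.pdf", expires_in=900
    )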

View File

@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The name of the Google Cloud Storage bucket for media files",
)
workspace_storage_dir: str = Field(
default="",
description="Local directory for workspace file storage when GCS is not configured. "
"If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
)
reddit_user_agent: str = Field(
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
description="The user agent for the Reddit API",
@@ -359,8 +365,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
- default=120,
- description="The timeout in seconds for Agent Generator service requests",
+ default=600,
+ description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
)
enable_example_blocks: bool = Field(
@@ -389,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum file size in MB for file uploads (1-1024 MB)",
)
max_file_size_mb: int = Field(
default=100,
ge=1,
le=1024,
description="Maximum file size in MB for workspace files (1-1024 MB)",
)
# AutoMod configuration
automod_enabled: bool = Field(
default=False,

View File

@@ -140,14 +140,29 @@ async def execute_block_test(block: Block):
setattr(block, mock_name, mock_obj)
# Populate credentials argument(s)
+ # Generate IDs for execution context
+ graph_id = str(uuid.uuid4())
+ node_id = str(uuid.uuid4())
+ graph_exec_id = str(uuid.uuid4())
+ node_exec_id = str(uuid.uuid4())
+ user_id = str(uuid.uuid4())
+ graph_version = 1  # Default version for tests
extra_exec_kwargs: dict = {
- "graph_id": str(uuid.uuid4()),
- "node_id": str(uuid.uuid4()),
- "graph_exec_id": str(uuid.uuid4()),
- "node_exec_id": str(uuid.uuid4()),
- "user_id": str(uuid.uuid4()),
- "graph_version": 1,  # Default version for tests
- "execution_context": ExecutionContext(),
+ "graph_id": graph_id,
+ "node_id": node_id,
+ "graph_exec_id": graph_exec_id,
+ "node_exec_id": node_exec_id,
+ "user_id": user_id,
+ "graph_version": graph_version,
+ "execution_context": ExecutionContext(
+     user_id=user_id,
+     graph_id=graph_id,
+     graph_exec_id=graph_exec_id,
+     graph_version=graph_version,
+     node_id=node_id,
+     node_exec_id=node_exec_id,
+ ),
}
input_model = cast(type[BlockSchema], block.input_schema)

View File

@@ -0,0 +1,419 @@
"""
WorkspaceManager for managing user workspace file operations.
This module provides a high-level interface for workspace file operations,
combining the storage backend and database layer.
"""
import logging
import mimetypes
import uuid
from typing import Optional
from prisma.errors import UniqueViolationError
from prisma.models import UserWorkspaceFile
from backend.data.workspace import (
count_workspace_files,
create_workspace_file,
get_workspace_file,
get_workspace_file_by_path,
list_workspace_files,
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
logger = logging.getLogger(__name__)
class WorkspaceManager:
"""
Manages workspace file operations.
Combines storage backend operations with database record management.
Supports session-scoped file segmentation where files are stored in
session-specific virtual paths: /sessions/{session_id}/{filename}
"""
def __init__(
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
):
"""
Initialize WorkspaceManager.
Args:
user_id: The user's ID
workspace_id: The workspace ID
session_id: Optional session ID for session-scoped file access
"""
self.user_id = user_id
self.workspace_id = workspace_id
self.session_id = session_id
# Session path prefix for file isolation
self.session_path = f"/sessions/{session_id}" if session_id else ""
def _resolve_path(self, path: str) -> str:
"""
Resolve a path, defaulting to session folder if session_id is set.
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
Args:
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
Returns:
Resolved path with session prefix if applicable
"""
# If path explicitly references a session folder, use it as-is
if path.startswith("/sessions/"):
return path
# If we have a session context, prepend session path
if self.session_path:
# Normalize the path
if not path.startswith("/"):
path = f"/{path}"
return f"{self.session_path}{path}"
# No session context, use path as-is
return path if path.startswith("/") else f"/{path}"
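The session scoping is easiest to see with concrete values (IDs hypothetical; calling the private method directly is for illustration only):

wm = WorkspaceManager(user_id="u1", workspace_id="w1", session_id="abc123")
assert wm._resolve_path("report.pdf") == "/sessions/abc123/report.pdf"
assert wm._resolve_path("/report.pdf") == "/sessions/abc123/report.pdf"
# Explicit cross-session paths are passed through untouched:
assert wm._resolve_path("/sessions/other/x.txt") == "/sessions/other/x.txt"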
def _get_effective_path(
self, path: Optional[str], include_all_sessions: bool
) -> Optional[str]:
"""
Get effective path for list/count operations based on session context.
Args:
path: Optional path prefix to filter
include_all_sessions: If True, don't apply session scoping
Returns:
Effective path prefix for database query
"""
if include_all_sessions:
# Normalize path to ensure leading slash (stored paths are normalized)
if path is not None and not path.startswith("/"):
return f"/{path}"
return path
elif path is not None:
# Resolve the provided path with session scoping
return self._resolve_path(path)
elif self.session_path:
# Default to session folder with trailing slash to prevent prefix collisions
# e.g., "/sessions/abc" should not match "/sessions/abc123"
return self.session_path.rstrip("/") + "/"
else:
# No session context, use path as-is
return path
async def read_file(self, path: str) -> bytes:
"""
Read file from workspace by virtual path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path (e.g., "/documents/report.pdf")
Returns:
File content as bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
resolved_path = self._resolve_path(path)
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
if file is None:
raise FileNotFoundError(f"File not found at path: {resolved_path}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
async def read_file_by_id(self, file_id: str) -> bytes:
"""
Read file from workspace by file ID.
Args:
file_id: The file's ID
Returns:
File content as bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.retrieve(file.storagePath)
async def write_file(
self,
content: bytes,
filename: str,
path: Optional[str] = None,
mime_type: Optional[str] = None,
overwrite: bool = False,
) -> UserWorkspaceFile:
"""
Write file to workspace.
When session_id is set, files are written to /sessions/{session_id}/...
by default. Use explicit /sessions/... paths for cross-session access.
Args:
content: File content as bytes
filename: Filename for the file
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
mime_type: MIME type (auto-detected if not provided)
overwrite: Whether to overwrite existing file at path
Returns:
Created UserWorkspaceFile instance
Raises:
ValueError: If file exceeds size limit or path already exists
"""
# Enforce file size limit
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
raise ValueError(
f"File too large: {len(content)} bytes exceeds "
f"{Config().max_file_size_mb}MB limit"
)
# Determine path with session scoping
if path is None:
path = f"/{filename}"
elif not path.startswith("/"):
path = f"/{path}"
# Resolve path with session prefix
path = self._resolve_path(path)
# Check if file exists at path (only error for non-overwrite case)
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
# This ensures the new file is written to storage BEFORE the old one is deleted,
# preventing data loss if the new write fails
if not overwrite:
existing = await get_workspace_file_by_path(self.workspace_id, path)
if existing is not None:
raise ValueError(f"File already exists at path: {path}")
# Auto-detect MIME type if not provided
if mime_type is None:
mime_type, _ = mimetypes.guess_type(filename)
mime_type = mime_type or "application/octet-stream"
# Compute checksum
checksum = compute_file_checksum(content)
# Generate unique file ID for storage
file_id = str(uuid.uuid4())
# Store file in storage backend
storage = await get_workspace_storage()
storage_path = await storage.store(
workspace_id=self.workspace_id,
file_id=file_id,
filename=filename,
content=content,
)
# Create database record - handle race condition where another request
# created a file at the same path between our check and create
try:
file = await create_workspace_file(
workspace_id=self.workspace_id,
file_id=file_id,
name=filename,
path=path,
storage_path=storage_path,
mime_type=mime_type,
size_bytes=len(content),
checksum=checksum,
)
except UniqueViolationError:
# Race condition: another request created a file at this path
if overwrite:
# Re-fetch and delete the conflicting file, then retry
existing = await get_workspace_file_by_path(self.workspace_id, path)
if existing:
await self.delete_file(existing.id)
# Retry the create - if this also fails, clean up storage file
try:
file = await create_workspace_file(
workspace_id=self.workspace_id,
file_id=file_id,
name=filename,
path=path,
storage_path=storage_path,
mime_type=mime_type,
size_bytes=len(content),
checksum=checksum,
)
except Exception:
# Clean up orphaned storage file on retry failure
try:
await storage.delete(storage_path)
except Exception as e:
logger.warning(f"Failed to clean up orphaned storage file: {e}")
raise
else:
# Clean up the orphaned storage file before raising
try:
await storage.delete(storage_path)
except Exception as e:
logger.warning(f"Failed to clean up orphaned storage file: {e}")
raise ValueError(f"File already exists at path: {path}")
except Exception:
# Any other database error (connection, validation, etc.) - clean up storage
try:
await storage.delete(storage_path)
except Exception as e:
logger.warning(f"Failed to clean up orphaned storage file: {e}")
raise
logger.info(
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
f"at path {path}, size={len(content)} bytes"
)
return file
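A round-trip sketch under a session context (content and filename illustrative); overwrite=True is the mode the store_media_file path above uses:

async def demo_roundtrip(wm: WorkspaceManager) -> None:
    record = await wm.write_file(b"hello", "notes.txt", overwrite=True)
    # Default path lands in the session folder: /sessions/{session_id}/notes.txt
    assert (await wm.read_file("notes.txt")) == b"hello"
    assert (await wm.read_file_by_id(record.id)) == b"hello"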
async def list_files(
self,
path: Optional[str] = None,
limit: Optional[int] = None,
offset: int = 0,
include_all_sessions: bool = False,
) -> list[UserWorkspaceFile]:
"""
List files in workspace.
When session_id is set and include_all_sessions is False (default),
only files in the current session's folder are listed.
Args:
path: Optional path prefix to filter (e.g., "/documents/")
limit: Maximum number of files to return
offset: Number of files to skip
include_all_sessions: If True, list files from all sessions.
If False (default), only list current session's files.
Returns:
List of UserWorkspaceFile instances
"""
effective_path = self._get_effective_path(path, include_all_sessions)
return await list_workspace_files(
workspace_id=self.workspace_id,
path_prefix=effective_path,
limit=limit,
offset=offset,
)
async def delete_file(self, file_id: str) -> bool:
"""
Delete a file (soft-delete).
Args:
file_id: The file's ID
Returns:
True if deleted, False if not found
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
return False
# Delete from storage
storage = await get_workspace_storage()
try:
await storage.delete(file.storagePath)
except Exception as e:
logger.warning(f"Failed to delete file from storage: {e}")
# Continue with database soft-delete even if storage delete fails
# Soft-delete database record
result = await soft_delete_workspace_file(file_id, self.workspace_id)
return result is not None
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
"""
Get download URL for a file.
Args:
file_id: The file's ID
expires_in: URL expiration in seconds (default 1 hour)
Returns:
Download URL (signed URL for GCS, API endpoint for local)
Raises:
FileNotFoundError: If file doesn't exist
"""
file = await get_workspace_file(file_id, self.workspace_id)
if file is None:
raise FileNotFoundError(f"File not found: {file_id}")
storage = await get_workspace_storage()
return await storage.get_download_url(file.storagePath, expires_in)
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
"""
Get file metadata.
Args:
file_id: The file's ID
Returns:
UserWorkspaceFile instance or None
"""
return await get_workspace_file(file_id, self.workspace_id)
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
"""
Get file metadata by path.
When session_id is set, paths are resolved relative to the session folder
unless they explicitly reference /sessions/...
Args:
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
resolved_path = self._resolve_path(path)
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
async def get_file_count(
self,
path: Optional[str] = None,
include_all_sessions: bool = False,
) -> int:
"""
Get number of files in workspace.
When session_id is set and include_all_sessions is False (default),
only counts files in the current session's folder.
Args:
path: Optional path prefix to filter (e.g., "/documents/")
include_all_sessions: If True, count all files in workspace.
If False (default), only count current session's files.
Returns:
Number of files
"""
effective_path = self._get_effective_path(path, include_all_sessions)
return await count_workspace_files(
self.workspace_id, path_prefix=effective_path
)

View File

@@ -0,0 +1,398 @@
"""
Workspace storage backend abstraction for supporting both cloud and local deployments.
This module provides a unified interface for storing workspace files, with implementations
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
"""
import asyncio
import hashlib
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import aiofiles
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage
from backend.util.data import get_data_path
from backend.util.gcs_utils import (
download_with_fresh_session,
generate_signed_url,
parse_gcs_path,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
class WorkspaceStorageBackend(ABC):
"""Abstract interface for workspace file storage."""
@abstractmethod
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""
Store file content, return storage path.
Args:
workspace_id: The workspace ID
file_id: Unique file ID for storage
filename: Original filename
content: File content as bytes
Returns:
Storage path string (cloud path or local path)
"""
pass
@abstractmethod
async def retrieve(self, storage_path: str) -> bytes:
"""
Retrieve file content from storage.
Args:
storage_path: The storage path returned from store()
Returns:
File content as bytes
"""
pass
@abstractmethod
async def delete(self, storage_path: str) -> None:
"""
Delete file from storage.
Args:
storage_path: The storage path to delete
"""
pass
@abstractmethod
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Get URL for downloading the file.
Args:
storage_path: The storage path
expires_in: URL expiration time in seconds (default 1 hour)
Returns:
Download URL (signed URL for GCS, direct API path for local)
"""
pass
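Because the interface is small, a test double is cheap to write. A hypothetical in-memory backend (not part of this PR) shows the contract:

class InMemoryWorkspaceStorage(WorkspaceStorageBackend):
    """Hypothetical in-memory backend, useful as a test double."""

    def __init__(self) -> None:
        self._files: dict[str, bytes] = {}

    async def store(
        self, workspace_id: str, file_id: str, filename: str, content: bytes
    ) -> str:
        storage_path = f"mem://{workspace_id}/{file_id}/{filename}"
        self._files[storage_path] = content
        return storage_path

    async def retrieve(self, storage_path: str) -> bytes:
        if storage_path not in self._files:
            raise FileNotFoundError(f"File not found: {storage_path}")
        return self._files[storage_path]

    async def delete(self, storage_path: str) -> None:
        self._files.pop(storage_path, None)

    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        raise NotImplementedError("In-memory storage has no download URLs")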
class GCSWorkspaceStorage(WorkspaceStorageBackend):
"""Google Cloud Storage implementation for workspace storage."""
def __init__(self, bucket_name: str):
self.bucket_name = bucket_name
self._async_client: Optional[async_gcs_storage.Storage] = None
self._sync_client: Optional[gcs_storage.Client] = None
self._session: Optional[aiohttp.ClientSession] = None
async def _get_async_client(self) -> async_gcs_storage.Storage:
"""Get or create async GCS client."""
if self._async_client is None:
self._session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=100, force_close=False)
)
self._async_client = async_gcs_storage.Storage(session=self._session)
return self._async_client
def _get_sync_client(self) -> gcs_storage.Client:
"""Get or create sync GCS client (for signed URLs)."""
if self._sync_client is None:
self._sync_client = gcs_storage.Client()
return self._sync_client
async def close(self) -> None:
"""Close all client connections."""
if self._async_client is not None:
try:
await self._async_client.close()
except Exception as e:
logger.warning(f"Error closing GCS client: {e}")
self._async_client = None
if self._session is not None:
try:
await self._session.close()
except Exception as e:
logger.warning(f"Error closing session: {e}")
self._session = None
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
"""Build the blob path for workspace files."""
return f"workspaces/{workspace_id}/{file_id}/{filename}"
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""Store file in GCS."""
client = await self._get_async_client()
blob_name = self._build_blob_name(workspace_id, file_id, filename)
# Upload with metadata
upload_time = datetime.now(timezone.utc)
await client.upload(
self.bucket_name,
blob_name,
content,
metadata={
"uploaded_at": upload_time.isoformat(),
"workspace_id": workspace_id,
"file_id": file_id,
},
)
return f"gcs://{self.bucket_name}/{blob_name}"
async def retrieve(self, storage_path: str) -> bytes:
"""Retrieve file from GCS."""
bucket_name, blob_name = parse_gcs_path(storage_path)
return await download_with_fresh_session(bucket_name, blob_name)
async def delete(self, storage_path: str) -> None:
"""Delete file from GCS."""
bucket_name, blob_name = parse_gcs_path(storage_path)
client = await self._get_async_client()
try:
await client.delete(bucket_name, blob_name)
except Exception as e:
if "404" not in str(e) and "Not Found" not in str(e):
raise
# File already deleted, that's fine
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Generate download URL for GCS file.
Attempts to generate a signed URL if running with service account credentials.
Falls back to an API proxy endpoint if signed URL generation fails
(e.g., when running locally with user OAuth credentials).
"""
bucket_name, blob_name = parse_gcs_path(storage_path)
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
blob_parts = blob_name.split("/")
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
# Try to generate signed URL (requires service account credentials)
try:
sync_client = self._get_sync_client()
return await generate_signed_url(
sync_client, bucket_name, blob_name, expires_in
)
except AttributeError as e:
# Signed URL generation requires service account with private key.
# When running with user OAuth credentials, fall back to API proxy.
if "private key" in str(e) and file_id:
logger.debug(
"Cannot generate signed URL (no service account credentials), "
"falling back to API proxy endpoint"
)
return f"/api/workspace/files/{file_id}/download"
raise
class LocalWorkspaceStorage(WorkspaceStorageBackend):
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
def __init__(self, base_dir: Optional[str] = None):
"""
Initialize local storage backend.
Args:
base_dir: Base directory for workspace storage.
If None, defaults to {app_data}/workspaces
"""
if base_dir:
self.base_dir = Path(base_dir)
else:
self.base_dir = Path(get_data_path()) / "workspaces"
# Ensure base directory exists
self.base_dir.mkdir(parents=True, exist_ok=True)
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
"""Build the local file path with path traversal protection."""
# Import here to avoid circular import
# (file.py imports workspace.py which imports workspace_storage.py)
from backend.util.file import sanitize_filename
# Sanitize filename to prevent path traversal (removes / and \ among others)
safe_filename = sanitize_filename(filename)
file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()
# Verify the resolved path is still under base_dir
if not file_path.is_relative_to(self.base_dir.resolve()):
raise ValueError("Invalid filename: path traversal detected")
return file_path
def _parse_storage_path(self, storage_path: str) -> Path:
"""Parse local storage path to filesystem path."""
if storage_path.startswith("local://"):
relative_path = storage_path[8:] # Remove "local://"
else:
relative_path = storage_path
full_path = (self.base_dir / relative_path).resolve()
# Security check: ensure path is under base_dir
# Use is_relative_to() for robust path containment check
# (handles case-insensitive filesystems and edge cases)
if not full_path.is_relative_to(self.base_dir.resolve()):
raise ValueError("Invalid storage path: path traversal detected")
return full_path
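The traversal check matters because storage paths come from database records; a hostile path is rejected (base directory illustrative):

storage = LocalWorkspaceStorage(base_dir="/tmp/workspaces-demo")
try:
    storage._parse_storage_path("local://../../etc/passwd")
except ValueError:
    pass  # resolved path escapes base_dir, so it is rejected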
async def store(
self,
workspace_id: str,
file_id: str,
filename: str,
content: bytes,
) -> str:
"""Store file locally."""
file_path = self._build_file_path(workspace_id, file_id, filename)
# Create parent directories
file_path.parent.mkdir(parents=True, exist_ok=True)
# Write file asynchronously
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
# Return relative path as storage path
relative_path = file_path.relative_to(self.base_dir)
return f"local://{relative_path}"
async def retrieve(self, storage_path: str) -> bytes:
"""Retrieve file from local storage."""
file_path = self._parse_storage_path(storage_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {storage_path}")
async with aiofiles.open(file_path, "rb") as f:
return await f.read()
async def delete(self, storage_path: str) -> None:
"""Delete file from local storage."""
file_path = self._parse_storage_path(storage_path)
if file_path.exists():
# Remove file
file_path.unlink()
# Clean up empty parent directories
parent = file_path.parent
while parent != self.base_dir:
try:
if parent.exists() and not any(parent.iterdir()):
parent.rmdir()
else:
break
except OSError:
break
parent = parent.parent
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
"""
Get download URL for local file.
For local storage, this returns an API endpoint path.
The actual serving is handled by the API layer.
"""
# Parse the storage path to get the components
if storage_path.startswith("local://"):
relative_path = storage_path[8:]
else:
relative_path = storage_path
# Return the API endpoint for downloading
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
parts = relative_path.split("/")
if len(parts) >= 2:
file_id = parts[1] # Second component is file_id
return f"/api/workspace/files/{file_id}/download"
else:
raise ValueError(f"Invalid storage path format: {storage_path}")
# Global storage backend instance
_workspace_storage: Optional[WorkspaceStorageBackend] = None
_storage_lock = asyncio.Lock()
async def get_workspace_storage() -> WorkspaceStorageBackend:
"""
Get the workspace storage backend instance.
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
"""
global _workspace_storage
if _workspace_storage is None:
async with _storage_lock:
if _workspace_storage is None:
config = Config()
if config.media_gcs_bucket_name:
logger.info(
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
)
_workspace_storage = GCSWorkspaceStorage(
config.media_gcs_bucket_name
)
else:
storage_dir = (
config.workspace_storage_dir
if config.workspace_storage_dir
else None
)
logger.info(
f"Using local workspace storage: {storage_dir or 'default'}"
)
_workspace_storage = LocalWorkspaceStorage(storage_dir)
return _workspace_storage
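Callers never pick a backend directly; a sketch of typical use (IDs and content illustrative):

async def demo_storage_roundtrip() -> None:
    storage = await get_workspace_storage()  # GCS or local, per config
    path = await storage.store(
        workspace_id="ws-1", file_id="f-1", filename="hello.txt", content=b"hi"
    )
    assert await storage.retrieve(path) == b"hi"
    await storage.delete(path)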
async def shutdown_workspace_storage() -> None:
"""
Properly shutdown the global workspace storage backend.
Closes aiohttp sessions and other resources for GCS backend.
Should be called during application shutdown.
"""
global _workspace_storage
if _workspace_storage is not None:
async with _storage_lock:
if _workspace_storage is not None:
if isinstance(_workspace_storage, GCSWorkspaceStorage):
await _workspace_storage.close()
_workspace_storage = None
def compute_file_checksum(content: bytes) -> str:
"""Compute SHA256 checksum of file content."""
return hashlib.sha256(content).hexdigest()
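The checksum is a plain SHA-256 hex digest (hashlib is already imported in this module), so an integrity check is a one-liner:

assert compute_file_checksum(b"hello") == hashlib.sha256(b"hello").hexdigest()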

View File

@@ -0,0 +1,2 @@
-- AlterEnum
ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT';

View File

@@ -0,0 +1,52 @@
-- CreateEnum
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
-- CreateTable
CREATE TABLE "UserWorkspace" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "UserWorkspaceFile" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"workspaceId" TEXT NOT NULL,
"name" TEXT NOT NULL,
"path" TEXT NOT NULL,
"storagePath" TEXT NOT NULL,
"mimeType" TEXT NOT NULL,
"sizeBytes" BIGINT NOT NULL,
"checksum" TEXT,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
"deletedAt" TIMESTAMP(3),
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
"sourceExecId" TEXT,
"sourceSessionId" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
-- CreateIndex
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
-- CreateIndex
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
-- AddForeignKey
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,16 @@
/*
Warnings:
- You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost.
- You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
- You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
*/
-- AlterTable
ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source",
DROP COLUMN "sourceExecId",
DROP COLUMN "sourceSessionId";
-- DropEnum
DROP TYPE "WorkspaceFileSource";

View File

@@ -63,6 +63,7 @@ model User {
IntegrationWebhooks IntegrationWebhook[]
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -81,6 +82,7 @@ enum OnboardingStep {
AGENT_INPUT
CONGRATS
// First Wins
VISIT_COPILOT
GET_RESULTS
MARKETPLACE_VISIT
MARKETPLACE_ADD_AGENT
@@ -136,6 +138,53 @@ model CoPilotUnderstanding {
@@index([userId])
}
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
//////////////// USER WORKSPACE TABLES /////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// User's persistent file storage workspace
model UserWorkspace {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
Files UserWorkspaceFile[]
@@index([userId])
}
// Individual files in a user's workspace
model UserWorkspaceFile {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
workspaceId String
Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
// File metadata
name String // User-visible filename
path String // Virtual path (e.g., "/documents/report.pdf")
storagePath String // Actual GCS or local storage path
mimeType String
sizeBytes BigInt
checksum String? // SHA256 for integrity
// File state
isDeleted Boolean @default(false)
deletedAt DateTime?
metadata Json @default("{}")
@@unique([workspaceId, path])
@@index([workspaceId, isDeleted])
}
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())

View File

@@ -34,3 +34,6 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV=
# PostHog Analytics
NEXT_PUBLIC_POSTHOG_KEY=
NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
# OpenAI (for voice transcription)
OPENAI_API_KEY=

View File

@@ -2,8 +2,9 @@
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
- import { resolveResponse, shouldShowOnboarding } from "@/app/api/helpers";
+ import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getHomepageRoute } from "@/lib/constants";
export default function OnboardingPage() {
const router = useRouter();
@@ -11,10 +12,13 @@ export default function OnboardingPage() {
useEffect(() => {
async function redirectToStep() {
try {
- // Check if onboarding is enabled
- const isEnabled = await shouldShowOnboarding();
- if (!isEnabled) {
-   router.replace("/");
+ // Check if onboarding is enabled (also gets chat flag for redirect)
+ const { shouldShowOnboarding, isChatEnabled } =
+   await getOnboardingStatus();
+ const homepageRoute = getHomepageRoute(isChatEnabled);
+ if (!shouldShowOnboarding) {
+   router.replace(homepageRoute);
return;
}
@@ -22,7 +26,7 @@ export default function OnboardingPage() {
// Handle completed onboarding
if (onboarding.completedSteps.includes("GET_RESULTS")) {
router.replace("/");
router.replace(homepageRoute);
return;
}

View File

@@ -1,8 +1,9 @@
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { NextResponse } from "next/server";
import { revalidatePath } from "next/cache";
- import { shouldShowOnboarding } from "@/app/api/helpers";
+ import { getOnboardingStatus } from "@/app/api/helpers";
// Handle the callback to complete the user session login
export async function GET(request: Request) {
@@ -25,11 +26,15 @@ export async function GET(request: Request) {
const api = new BackendAPI();
await api.createUser();
- if (await shouldShowOnboarding()) {
+ // Get onboarding status from backend (includes chat flag evaluated for this user)
+ const { shouldShowOnboarding, isChatEnabled } =
+   await getOnboardingStatus();
+ if (shouldShowOnboarding) {
next = "/onboarding";
revalidatePath("/onboarding", "layout");
} else {
revalidatePath("/", "layout");
next = getHomepageRoute(isChatEnabled);
revalidatePath(next, "layout");
}
} catch (createUserError) {
console.error("Error creating user:", createUserError);

View File

@@ -1,12 +1,10 @@
"use client";
+ import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
- import { useEffect } from "react";
- import { useCopilotStore } from "../../copilot-page-store";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
- import { LoadingState } from "./components/LoadingState/LoadingState";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";
@@ -20,38 +18,21 @@ export function CopilotShell({ children }: Props) {
isMobile,
isDrawerOpen,
isLoading,
+ isCreatingSession,
isLoggedIn,
hasActiveSession,
sessions,
currentSessionId,
- handleSelectSession,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
- handleNewChat,
+ handleNewChatClick,
+ handleSessionClick,
hasNextPage,
isFetchingNextPage,
fetchNextPage,
- isReadyToShowContent,
} = useCopilotShell();
- const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
- const requestNewChat = useCopilotStore((s) => s.requestNewChat);
- useEffect(
-   function registerNewChatHandler() {
-     setNewChatHandler(handleNewChat);
-     return function cleanup() {
-       setNewChatHandler(null);
-     };
-   },
-   [handleNewChat],
- );
- function handleNewChatClick() {
-   requestNewChat();
- }
if (!isLoggedIn) {
return (
<div className="flex h-full items-center justify-center">
@@ -72,7 +53,7 @@ export function CopilotShell({ children }: Props) {
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
- onSelectSession={handleSelectSession}
+ onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
hasActiveSession={Boolean(hasActiveSession)}
@@ -82,7 +63,18 @@ export function CopilotShell({ children }: Props) {
<div className="relative flex min-h-0 flex-1 flex-col">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex min-h-0 flex-1 flex-col">
- {isReadyToShowContent ? children : <LoadingState />}
+ {isCreatingSession ? (
+   <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
+     <div className="flex flex-col items-center gap-4">
+       <ChatLoader />
+       <Text variant="body" className="text-zinc-500">
+         Creating your chat...
+       </Text>
+     </div>
+   </div>
+ ) : (
+   children
+ )}
</div>
</div>
@@ -94,7 +86,7 @@ export function CopilotShell({ children }: Props) {
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
- onSelectSession={handleSelectSession}
+ onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
onClose={handleCloseDrawer}

View File

@@ -1,15 +0,0 @@
import { Text } from "@/components/atoms/Text/Text";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
export function LoadingState() {
return (
<div className="flex flex-1 items-center justify-center">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Loading your chats...
</Text>
</div>
</div>
);
}

View File

@@ -3,17 +3,17 @@ import { useState } from "react";
export function useMobileDrawer() {
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
- function handleOpenDrawer() {
+ const handleOpenDrawer = () => {
    setIsDrawerOpen(true);
- }
+ };
- function handleCloseDrawer() {
+ const handleCloseDrawer = () => {
    setIsDrawerOpen(false);
- }
+ };
- function handleDrawerOpenChange(open: boolean) {
+ const handleDrawerOpenChange = (open: boolean) => {
    setIsDrawerOpen(open);
- }
+ };
return {
isDrawerOpen,

View File

@@ -1,11 +1,6 @@
- import {
-   getGetV2ListSessionsQueryKey,
-   useGetV2ListSessions,
- } from "@/app/api/__generated__/endpoints/chat/chat";
+ import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
- import { useChatStore } from "@/components/contextual/Chat/chat-store";
- import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
const PAGE_SIZE = 50;
@@ -16,12 +11,12 @@ export interface UseSessionsPaginationArgs {
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
const [offset, setOffset] = useState(0);
const [accumulatedSessions, setAccumulatedSessions] = useState<
SessionSummaryResponse[]
>([]);
const [totalCount, setTotalCount] = useState<number | null>(null);
- const queryClient = useQueryClient();
- const onStreamComplete = useChatStore((state) => state.onStreamComplete);
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
{ limit: PAGE_SIZE, offset },
@@ -32,38 +27,23 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
},
);
- useEffect(function refreshOnStreamComplete() {
-   const unsubscribe = onStreamComplete(function handleStreamComplete() {
-     setOffset(0);
+ useEffect(() => {
+   const responseData = okData(data);
+   if (responseData) {
+     const newSessions = responseData.sessions;
+     const total = responseData.total;
+     setTotalCount(total);
+     if (offset === 0) {
+       setAccumulatedSessions(newSessions);
+     } else {
+       setAccumulatedSessions((prev) => [...prev, ...newSessions]);
+     }
+   } else if (!enabled) {
+     setAccumulatedSessions([]);
+     setTotalCount(null);
-     queryClient.invalidateQueries({
-       queryKey: getGetV2ListSessionsQueryKey(),
-     });
-   });
-   return unsubscribe;
- }, []);
- useEffect(
-   function updateSessionsFromResponse() {
-     const responseData = okData(data);
-     if (responseData) {
-       const newSessions = responseData.sessions;
-       const total = responseData.total;
-       setTotalCount(total);
-       if (offset === 0) {
-         setAccumulatedSessions(newSessions);
-       } else {
-         setAccumulatedSessions((prev) => [...prev, ...newSessions]);
-       }
-     } else if (!enabled) {
-       setAccumulatedSessions([]);
-       setTotalCount(null);
-     }
-   },
-   [data, offset, enabled],
- );
+ }
+ }, [data, offset, enabled]);
const hasNextPage =
totalCount !== null && accumulatedSessions.length < totalCount;
@@ -86,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
}
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
- function fetchNextPage() {
+ const fetchNextPage = () => {
    if (hasNextPage && !isFetching) {
      setOffset((prev) => prev + PAGE_SIZE);
    }
- }
+ };
- function reset() {
+ const reset = () => {
+   // Only reset the offset - keep existing sessions visible during refetch
+   // The effect will replace sessions when new data arrives at offset 0
    setOffset(0);
-   setAccumulatedSessions([]);
-   setTotalCount(null);
- }
+ };
return {
sessions: accumulatedSessions,

View File

@@ -104,76 +104,3 @@ export function mergeCurrentSessionIntoList(
export function getCurrentSessionId(searchParams: URLSearchParams) {
return searchParams.get("sessionId");
}
export function shouldAutoSelectSession(
areAllSessionsLoaded: boolean,
hasAutoSelectedSession: boolean,
paramSessionId: string | null,
visibleSessions: SessionSummaryResponse[],
accumulatedSessions: SessionSummaryResponse[],
isLoading: boolean,
totalCount: number | null,
) {
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
return {
shouldSelect: false,
sessionIdToSelect: null,
shouldCreate: false,
};
}
if (paramSessionId) {
return {
shouldSelect: false,
sessionIdToSelect: null,
shouldCreate: false,
};
}
if (visibleSessions.length > 0) {
return {
shouldSelect: true,
sessionIdToSelect: visibleSessions[0].id,
shouldCreate: false,
};
}
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
}
if (totalCount === 0) {
return {
shouldSelect: false,
sessionIdToSelect: null,
shouldCreate: false,
};
}
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
}
export function checkReadyToShowContent(
areAllSessionsLoaded: boolean,
paramSessionId: string | null,
accumulatedSessions: SessionSummaryResponse[],
isCurrentSessionLoading: boolean,
currentSessionData: SessionDetailResponse | null | undefined,
hasAutoSelectedSession: boolean,
) {
if (!areAllSessionsLoaded) return false;
if (paramSessionId) {
const sessionFound = accumulatedSessions.some(
(s) => s.id === paramSessionId,
);
return (
sessionFound ||
(!isCurrentSessionLoading &&
currentSessionData !== undefined &&
currentSessionData !== null)
);
}
return hasAutoSelectedSession;
}

View File

@@ -1,26 +1,22 @@
"use client";
import {
  getGetV2GetSessionQueryKey,
  getGetV2ListSessionsQueryKey,
  useGetV2GetSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
- import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
- import { parseAsString, useQueryState } from "nuqs";
import { usePathname, useSearchParams } from "next/navigation";
- import { useEffect, useRef, useState } from "react";
+ import { useRef } from "react";
import { useCopilotStore } from "../../copilot-page-store";
+ import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
- import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
- import {
-   checkReadyToShowContent,
-   convertSessionDetailToSummary,
-   filterVisibleSessions,
-   getCurrentSessionId,
-   mergeCurrentSessionIntoList,
- } from "./helpers";
+ import { getCurrentSessionId } from "./helpers";
+ import { useShellSessionList } from "./useShellSessionList";
export function useCopilotShell() {
const pathname = usePathname();
@@ -31,7 +27,7 @@ export function useCopilotShell() {
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
- const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);
+ const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const isOnHomepage = pathname === "/copilot";
const paramSessionId = searchParams.get("sessionId");
@@ -45,123 +41,80 @@ export function useCopilotShell() {
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
const {
sessions: accumulatedSessions,
isLoading: isSessionsLoading,
isFetching: isSessionsFetching,
hasNextPage,
areAllSessionsLoaded,
fetchNextPage,
reset: resetPagination,
} = useSessionsPagination({
enabled: paginationEnabled,
});
const currentSessionId = getCurrentSessionId(searchParams);
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
useGetV2GetSession(currentSessionId || "", {
const { data: currentSessionData } = useGetV2GetSession(
currentSessionId || "",
{
query: {
enabled: !!currentSessionId,
select: okData,
},
});
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
const hasAutoSelectedRef = useRef(false);
const recentlyCreatedSessionsRef = useRef<
Map<string, SessionSummaryResponse>
>(new Map());
// Mark as auto-selected when sessionId is in URL
useEffect(() => {
if (paramSessionId && !hasAutoSelectedRef.current) {
hasAutoSelectedRef.current = true;
setHasAutoSelectedSession(true);
}
}, [paramSessionId]);
// On homepage without sessionId, mark as ready immediately
useEffect(() => {
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
hasAutoSelectedRef.current = true;
setHasAutoSelectedSession(true);
}
}, [isOnHomepage, paramSessionId]);
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
useEffect(() => {
if (isOnHomepage && !paramSessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
}, [isOnHomepage, paramSessionId, queryClient]);
// Track newly created sessions to ensure they stay visible even when switching away
useEffect(() => {
if (currentSessionId && currentSessionData) {
const isNewSession =
currentSessionData.updated_at === currentSessionData.created_at;
const isNotInAccumulated = !accumulatedSessions.some(
(s) => s.id === currentSessionId,
);
if (isNewSession || isNotInAccumulated) {
const summary = convertSessionDetailToSummary(currentSessionData);
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
}
}
}, [currentSessionId, currentSessionData, accumulatedSessions]);
// Clean up recently created sessions that are now in the accumulated list
useEffect(() => {
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
if (accumulatedSessions.some((s) => s.id === sessionId)) {
recentlyCreatedSessionsRef.current.delete(sessionId);
}
}
}, [accumulatedSessions]);
// Reset pagination when query becomes disabled
const prevPaginationEnabledRef = useRef(paginationEnabled);
useEffect(() => {
if (prevPaginationEnabledRef.current && !paginationEnabled) {
resetPagination();
resetAutoSelect();
}
prevPaginationEnabledRef.current = paginationEnabled;
}, [paginationEnabled, resetPagination]);
const sessions = mergeCurrentSessionIntoList(
accumulatedSessions,
currentSessionId,
currentSessionData,
recentlyCreatedSessionsRef.current,
},
);
const visibleSessions = filterVisibleSessions(sessions);
const {
sessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
} = useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
});
const sidebarSelectedSessionId =
isOnHomepage && !paramSessionId ? null : currentSessionId;
const stopStream = useChatStore((s) => s.stopStream);
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const isStreaming = useCopilotStore((s) => s.isStreaming);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
const isReadyToShowContent = isOnHomepage
? true
: checkReadyToShowContent(
areAllSessionsLoaded,
paramSessionId,
accumulatedSessions,
isCurrentSessionLoading,
currentSessionData,
hasAutoSelectedSession,
);
const pendingActionRef = useRef<(() => void) | null>(null);
function handleSelectSession(sessionId: string) {
async function stopCurrentStream() {
if (!currentSessionId) return;
setIsSwitchingSession(true);
await new Promise<void>((resolve) => {
const unsubscribe = onStreamComplete((completedId) => {
if (completedId === currentSessionId) {
clearTimeout(timeout);
unsubscribe();
resolve();
}
});
const timeout = setTimeout(() => {
unsubscribe();
resolve();
}, 3000);
stopStream(currentSessionId);
});
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
});
setIsSwitchingSession(false);
}
function selectSession(sessionId: string) {
if (sessionId === currentSessionId) return;
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
}
setUrlSessionId(sessionId, { shallow: false });
if (isMobile) handleCloseDrawer();
}
function handleNewChat() {
resetAutoSelect();
function startNewChat() {
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
@@ -170,12 +123,31 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function resetAutoSelect() {
hasAutoSelectedRef.current = false;
setHasAutoSelectedSession(false);
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
selectSession(sessionId);
};
openInterruptModal(pendingActionRef.current);
} else {
selectSession(sessionId);
}
}
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
function handleNewChatClick() {
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
startNewChat();
};
openInterruptModal(pendingActionRef.current);
} else {
startNewChat();
}
}
return {
isMobile,
@@ -183,17 +155,17 @@ export function useCopilotShell() {
isLoggedIn,
hasActiveSession:
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
isLoading,
sessions: visibleSessions,
currentSessionId: sidebarSelectedSessionId,
handleSelectSession,
isLoading: isLoading || isCreatingSession,
isCreatingSession,
sessions,
currentSessionId: urlSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChat,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage: isSessionsFetching,
fetchNextPage,
isReadyToShowContent,
};
}

View File

@@ -0,0 +1,113 @@
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef } from "react";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
convertSessionDetailToSummary,
filterVisibleSessions,
mergeCurrentSessionIntoList,
} from "./helpers";
interface UseShellSessionListArgs {
paginationEnabled: boolean;
currentSessionId: string | null;
currentSessionData: SessionDetailResponse | null | undefined;
isOnHomepage: boolean;
paramSessionId: string | null;
}
export function useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
}: UseShellSessionListArgs) {
const queryClient = useQueryClient();
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const {
sessions: accumulatedSessions,
isLoading: isSessionsLoading,
isFetching: isSessionsFetching,
hasNextPage,
fetchNextPage,
reset: resetPagination,
} = useSessionsPagination({
enabled: paginationEnabled,
});
const recentlyCreatedSessionsRef = useRef<
Map<string, SessionSummaryResponse>
>(new Map());
useEffect(() => {
if (isOnHomepage && !paramSessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
}, [isOnHomepage, paramSessionId, queryClient]);
useEffect(() => {
if (currentSessionId && currentSessionData) {
const isNewSession =
currentSessionData.updated_at === currentSessionData.created_at;
const isNotInAccumulated = !accumulatedSessions.some(
(s) => s.id === currentSessionId,
);
if (isNewSession || isNotInAccumulated) {
const summary = convertSessionDetailToSummary(currentSessionData);
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
}
}
}, [currentSessionId, currentSessionData, accumulatedSessions]);
useEffect(() => {
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
if (accumulatedSessions.some((s) => s.id === sessionId)) {
recentlyCreatedSessionsRef.current.delete(sessionId);
}
}
}, [accumulatedSessions]);
useEffect(() => {
const unsubscribe = onStreamComplete(() => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
});
return unsubscribe;
}, [onStreamComplete, queryClient]);
const sessions = useMemo(
() =>
mergeCurrentSessionIntoList(
accumulatedSessions,
currentSessionId,
currentSessionData,
recentlyCreatedSessionsRef.current,
),
[accumulatedSessions, currentSessionId, currentSessionData],
);
const visibleSessions = useMemo(
() => filterVisibleSessions(sessions),
[sessions],
);
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
return {
sessions: visibleSessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
};
}

View File

@@ -4,51 +4,53 @@ import { create } from "zustand";
interface CopilotStoreState {
isStreaming: boolean;
- isNewChatModalOpen: boolean;
- newChatHandler: (() => void) | null;
+ isSwitchingSession: boolean;
+ isCreatingSession: boolean;
+ isInterruptModalOpen: boolean;
+ pendingAction: (() => void) | null;
}
interface CopilotStoreActions {
setIsStreaming: (isStreaming: boolean) => void;
- setNewChatHandler: (handler: (() => void) | null) => void;
- requestNewChat: () => void;
- confirmNewChat: () => void;
- cancelNewChat: () => void;
+ setIsSwitchingSession: (isSwitchingSession: boolean) => void;
+ setIsCreatingSession: (isCreating: boolean) => void;
+ openInterruptModal: (onConfirm: () => void) => void;
+ confirmInterrupt: () => void;
+ cancelInterrupt: () => void;
}
type CopilotStore = CopilotStoreState & CopilotStoreActions;
export const useCopilotStore = create<CopilotStore>((set, get) => ({
isStreaming: false,
- isNewChatModalOpen: false,
- newChatHandler: null,
+ isSwitchingSession: false,
+ isCreatingSession: false,
+ isInterruptModalOpen: false,
+ pendingAction: null,
setIsStreaming(isStreaming) {
  set({ isStreaming });
},
- setNewChatHandler(handler) {
-   set({ newChatHandler: handler });
+ setIsSwitchingSession(isSwitchingSession) {
+   set({ isSwitchingSession });
},
- requestNewChat() {
-   const { isStreaming, newChatHandler } = get();
-   if (isStreaming) {
-     set({ isNewChatModalOpen: true });
-   } else if (newChatHandler) {
-     newChatHandler();
-   }
+ setIsCreatingSession(isCreatingSession) {
+   set({ isCreatingSession });
},
- confirmNewChat() {
-   const { newChatHandler } = get();
-   set({ isNewChatModalOpen: false });
-   if (newChatHandler) {
-     newChatHandler();
-   }
+ openInterruptModal(onConfirm) {
+   set({ isInterruptModalOpen: true, pendingAction: onConfirm });
},
- cancelNewChat() {
-   set({ isNewChatModalOpen: false });
+ confirmInterrupt() {
+   const { pendingAction } = get();
+   set({ isInterruptModalOpen: false, pendingAction: null });
+   if (pendingAction) pendingAction();
},
+ cancelInterrupt() {
+   set({ isInterruptModalOpen: false, pendingAction: null });
+ },
}));

View File

@@ -1,28 +1,5 @@
import type { User } from "@supabase/supabase-js";
export type PageState =
| { type: "welcome" }
| { type: "newChat" }
| { type: "creating"; prompt: string }
| { type: "chat"; sessionId: string; initialPrompt?: string };
export function getInitialPromptFromState(
pageState: PageState,
storedInitialPrompt: string | undefined,
) {
if (storedInitialPrompt) return storedInitialPrompt;
if (pageState.type === "creating") return pageState.prompt;
if (pageState.type === "chat") return pageState.initialPrompt;
}
export function shouldResetToWelcome(pageState: PageState) {
return (
pageState.type !== "newChat" &&
pageState.type !== "creating" &&
pageState.type !== "welcome"
);
}
export function getGreetingName(user?: User | null): string {
if (!user) return "there";
const metadata = user.user_metadata as Record<string, unknown> | undefined;

View File

@@ -1,25 +1,25 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useCopilotStore } from "./copilot-page-store";
import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() {
const { state, handlers } = useCopilotPage();
const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const {
greetingName,
quickActions,
isLoading,
pageState,
isNewChatModalOpen,
hasSession,
initialPrompt,
isReady,
} = state;
const {
@@ -27,20 +27,16 @@ export default function CopilotPage() {
startChatWithPrompt,
handleSessionNotFound,
handleStreamingChange,
handleCancelNewChat,
handleNewChatModalOpen,
} = handlers;
if (!isReady) return null;
if (pageState.type === "chat") {
if (hasSession) {
return (
<div className="flex h-full flex-col">
<Chat
key={pageState.sessionId ?? "welcome"}
className="flex-1"
urlSessionId={pageState.sessionId}
initialPrompt={pageState.initialPrompt}
initialPrompt={initialPrompt}
onSessionNotFound={handleSessionNotFound}
onStreamingChange={handleStreamingChange}
/>
@@ -48,31 +44,33 @@ export default function CopilotPage() {
title="Interrupt current chat?"
styling={{ maxWidth: 300, width: "100%" }}
controlled={{
isOpen: isNewChatModalOpen,
set: handleNewChatModalOpen,
isOpen: isInterruptModalOpen,
set: (open) => {
if (!open) cancelInterrupt();
},
}}
onClose={handleCancelNewChat}
onClose={cancelInterrupt}
>
<Dialog.Content>
<div className="flex flex-col gap-4">
<Text variant="body">
The current chat response will be interrupted. Are you sure you
want to start a new chat?
want to continue?
</Text>
<Dialog.Footer>
<Button
type="button"
variant="outline"
onClick={handleCancelNewChat}
onClick={cancelInterrupt}
>
Cancel
</Button>
<Button
type="button"
variant="primary"
onClick={confirmNewChat}
onClick={confirmInterrupt}
>
Start new chat
Continue
</Button>
</Dialog.Footer>
</div>
@@ -82,19 +80,6 @@ export default function CopilotPage() {
);
}
if (pageState.type === "newChat" || pageState.type === "creating") {
return (
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Loading your chats...
</Text>
</div>
</div>
);
}
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
<div className="w-full text-center">

View File

@@ -5,79 +5,40 @@ import {
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
Flag,
type FlagValues,
useGetFlag,
} from "@/services/feature-flags/use-get-flag";
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
import { useFlags } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect, useReducer } from "react";
import { useEffect } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
import { useCopilotURLState } from "./useCopilotURLState";
type CopilotState = {
pageState: PageState;
initialPrompts: Record<string, string>;
previousSessionId: string | null;
};
type CopilotAction =
| { type: "setPageState"; pageState: PageState }
| { type: "setInitialPrompt"; sessionId: string; prompt: string }
| { type: "setPreviousSessionId"; sessionId: string | null };
function isSamePageState(next: PageState, current: PageState) {
if (next.type !== current.type) return false;
if (next.type === "creating" && current.type === "creating") {
return next.prompt === current.prompt;
}
if (next.type === "chat" && current.type === "chat") {
return (
next.sessionId === current.sessionId &&
next.initialPrompt === current.initialPrompt
);
}
return true;
}
function copilotReducer(
state: CopilotState,
action: CopilotAction,
): CopilotState {
if (action.type === "setPageState") {
if (isSamePageState(action.pageState, state.pageState)) return state;
return { ...state, pageState: action.pageState };
}
if (action.type === "setInitialPrompt") {
if (state.initialPrompts[action.sessionId] === action.prompt) return state;
return {
...state,
initialPrompts: {
...state.initialPrompts,
[action.sessionId]: action.prompt,
},
};
}
if (action.type === "setPreviousSessionId") {
if (state.previousSessionId === action.sessionId) return state;
return { ...state, previousSessionId: action.sessionId };
}
return state;
}
import { getGreetingName, getQuickActions } from "./helpers";
import { useCopilotSessionId } from "./useCopilotSessionId";
export function useCopilotPage() {
const router = useRouter();
const queryClient = useQueryClient();
const { user, isLoggedIn, isUserLoading } = useSupabase();
const { toast } = useToast();
const { completeStep } = useOnboarding();
const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
useEffect(() => {
if (isLoggedIn) {
completeStep("VISIT_COPILOT");
}
}, [completeStep, isLoggedIn]);
const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
@@ -88,72 +49,27 @@ export function useCopilotPage() {
const isFlagReady =
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
const [state, dispatch] = useReducer(copilotReducer, {
pageState: { type: "welcome" },
initialPrompts: {},
previousSessionId: null,
});
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
function setPageState(pageState: PageState) {
dispatch({ type: "setPageState", pageState });
}
const hasSession = Boolean(urlSessionId);
const initialPrompt = urlSessionId
? getInitialPrompt(urlSessionId)
: undefined;
function setInitialPrompt(sessionId: string, prompt: string) {
dispatch({ type: "setInitialPrompt", sessionId, prompt });
}
function setPreviousSessionId(sessionId: string | null) {
dispatch({ type: "setPreviousSessionId", sessionId });
}
const { setUrlSessionId } = useCopilotURLState({
pageState: state.pageState,
initialPrompts: state.initialPrompts,
previousSessionId: state.previousSessionId,
setPageState,
setInitialPrompt,
setPreviousSessionId,
});
useEffect(
function transitionNewChatToWelcome() {
if (state.pageState.type === "newChat") {
function setWelcomeState() {
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
}
const timer = setTimeout(setWelcomeState, 300);
return function cleanup() {
clearTimeout(timer);
};
}
},
[state.pageState.type],
);
useEffect(
function ensureAccess() {
if (!isFlagReady) return;
if (isChatEnabled === false) {
router.replace(homepageRoute);
}
},
[homepageRoute, isChatEnabled, isFlagReady, router],
);
useEffect(() => {
if (!isFlagReady) return;
if (isChatEnabled === false) {
router.replace(homepageRoute);
}
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
async function startChatWithPrompt(prompt: string) {
if (!prompt?.trim()) return;
if (state.pageState.type === "creating") return;
if (isCreating) return;
const trimmedPrompt = prompt.trim();
dispatch({
type: "setPageState",
pageState: { type: "creating", prompt: trimmedPrompt },
});
setIsCreating(true);
try {
const sessionResponse = await postV2CreateSession({
@@ -165,27 +81,19 @@ export function useCopilotPage() {
}
const sessionId = sessionResponse.data.id;
dispatch({
type: "setInitialPrompt",
sessionId,
prompt: trimmedPrompt,
});
setInitialPrompt(sessionId, trimmedPrompt);
await queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
await setUrlSessionId(sessionId, { shallow: false });
dispatch({
type: "setPageState",
pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
});
await setUrlSessionId(sessionId, { shallow: true });
} catch (error) {
console.error("[CopilotPage] Failed to start chat:", error);
toast({ title: "Failed to start chat", variant: "destructive" });
Sentry.captureException(error);
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
} finally {
setIsCreating(false);
}
}
@@ -201,21 +109,13 @@ export function useCopilotPage() {
setIsStreaming(isStreamingValue);
}
function handleCancelNewChat() {
cancelNewChat();
}
function handleNewChatModalOpen(isOpen: boolean) {
if (!isOpen) cancelNewChat();
}
return {
state: {
greetingName,
quickActions,
isLoading: isUserLoading,
pageState: state.pageState,
isNewChatModalOpen,
hasSession,
initialPrompt,
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
},
handlers: {
@@ -223,8 +123,32 @@ export function useCopilotPage() {
startChatWithPrompt,
handleSessionNotFound,
handleStreamingChange,
handleCancelNewChat,
handleNewChatModalOpen,
},
};
}
function getInitialPrompt(sessionId: string): string | undefined {
try {
const prompts = JSON.parse(
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
);
return prompts[sessionId];
} catch {
return undefined;
}
}
function setInitialPrompt(sessionId: string, prompt: string): void {
try {
const prompts = JSON.parse(
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
);
prompts[sessionId] = prompt;
sessionStorage.set(
SessionKey.CHAT_INITIAL_PROMPTS,
JSON.stringify(prompts),
);
} catch {
// Ignore storage errors
}
}
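A quick round-trip of the two sessionStorage helpers above (values hypothetical):

// Stash the prompt before navigating; read it back when the chat mounts.
setInitialPrompt("session-123", "Build me a newsletter agent");
getInitialPrompt("session-123"); // => "Build me a newsletter agent"
getInitialPrompt("unknown-id"); // => undefined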

View File

@@ -0,0 +1,10 @@
import { parseAsString, useQueryState } from "nuqs";
export function useCopilotSessionId() {
const [urlSessionId, setUrlSessionId] = useQueryState(
"sessionId",
parseAsString,
);
return { urlSessionId, setUrlSessionId };
}
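A hypothetical consumer of this hook, using the same shallow option that useCopilotPage passes when it sets the id after creating a session:

// Sketch only: ExampleSessionSwitcher is not part of this diff.
function ExampleSessionSwitcher({ id }: { id: string }) {
  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
  // shallow: true rewrites the query string without a server round-trip.
  return (
    <button onClick={() => setUrlSessionId(id, { shallow: true })}>
      {urlSessionId === id ? "Current session" : "Open session"}
    </button>
  );
}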

View File

@@ -1,80 +0,0 @@
import { parseAsString, useQueryState } from "nuqs";
import { useLayoutEffect } from "react";
import {
getInitialPromptFromState,
type PageState,
shouldResetToWelcome,
} from "./helpers";
interface UseCopilotUrlStateArgs {
pageState: PageState;
initialPrompts: Record<string, string>;
previousSessionId: string | null;
setPageState: (pageState: PageState) => void;
setInitialPrompt: (sessionId: string, prompt: string) => void;
setPreviousSessionId: (sessionId: string | null) => void;
}
export function useCopilotURLState({
pageState,
initialPrompts,
previousSessionId,
setPageState,
setInitialPrompt,
setPreviousSessionId,
}: UseCopilotUrlStateArgs) {
const [urlSessionId, setUrlSessionId] = useQueryState(
"sessionId",
parseAsString,
);
function syncSessionFromUrl() {
if (urlSessionId) {
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
setPreviousSessionId(urlSessionId);
return;
}
const storedInitialPrompt = initialPrompts[urlSessionId];
const currentInitialPrompt = getInitialPromptFromState(
pageState,
storedInitialPrompt,
);
if (currentInitialPrompt) {
setInitialPrompt(urlSessionId, currentInitialPrompt);
}
setPageState({
type: "chat",
sessionId: urlSessionId,
initialPrompt: currentInitialPrompt,
});
setPreviousSessionId(urlSessionId);
return;
}
const wasInChat = previousSessionId !== null && pageState.type === "chat";
setPreviousSessionId(null);
if (wasInChat) {
setPageState({ type: "newChat" });
return;
}
if (shouldResetToWelcome(pageState)) {
setPageState({ type: "welcome" });
}
}
useLayoutEffect(syncSessionFromUrl, [
urlSessionId,
pageState.type,
previousSessionId,
initialPrompts,
]);
return {
urlSessionId,
setUrlSessionId,
};
}

View File

@@ -1,10 +1,11 @@
"use server";
import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { loginFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
import { shouldShowOnboarding } from "../../api/helpers";
import { getOnboardingStatus } from "../../api/helpers";
export async function login(email: string, password: string) {
try {
@@ -36,11 +37,15 @@ export async function login(email: string, password: string) {
const api = new BackendAPI();
await api.createUser();
const onboarding = await shouldShowOnboarding();
// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
const next = shouldShowOnboarding
? "/onboarding"
: getHomepageRoute(isChatEnabled);
return {
success: true,
onboarding,
next,
};
} catch (err) {
Sentry.captureException(err);

View File

@@ -97,13 +97,8 @@ export function useLoginPage() {
throw new Error(result.error || "Login failed");
}
if (nextUrl) {
router.replace(nextUrl);
} else if (result.onboarding) {
router.replace("/onboarding");
} else {
router.replace(homepageRoute);
}
// Prefer the URL's next parameter, then fall back to the backend-determined route
router.replace(nextUrl || result.next || homepageRoute);
} catch (error) {
toast({
title:

View File

@@ -5,14 +5,13 @@ import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { signupFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
import { isWaitlistError, logWaitlistError } from "../../api/auth/utils";
import { shouldShowOnboarding } from "../../api/helpers";
import { getOnboardingStatus } from "../../api/helpers";
export async function signup(
email: string,
password: string,
confirmPassword: string,
agreeToTerms: boolean,
isChatEnabled: boolean,
) {
try {
const parsed = signupFormSchema.safeParse({
@@ -59,8 +58,9 @@ export async function signup(
await supabase.auth.setSession(data.session);
}
const isOnboardingEnabled = await shouldShowOnboarding();
const next = isOnboardingEnabled
// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
const next = shouldShowOnboarding
? "/onboarding"
: getHomepageRoute(isChatEnabled);

View File

@@ -108,7 +108,6 @@ export function useSignupPage() {
data.password,
data.confirmPassword,
data.agreeToTerms,
isChatEnabled === true,
);
setIsLoading(false);

View File

@@ -175,9 +175,12 @@ export async function resolveResponse<
return res.data;
}
export async function shouldShowOnboarding() {
const isEnabled = await resolveResponse(getV1IsOnboardingEnabled());
export async function getOnboardingStatus() {
const status = await resolveResponse(getV1IsOnboardingEnabled());
const onboarding = await resolveResponse(getV1OnboardingState());
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
return isEnabled && !isCompleted;
return {
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
isChatEnabled: status.is_chat_enabled,
};
}
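The payload resolved above matches the OnboardingStatusResponse schema added to the OpenAPI spec further down; as a TypeScript shape:

// Both fields are required booleans per the schema.
type OnboardingStatusResponse = {
  is_onboarding_enabled: boolean;
  is_chat_enabled: boolean;
};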

View File

@@ -3339,7 +3339,7 @@
"get": {
"tags": ["v2", "library", "private"],
"summary": "List Library Agents",
"description": "Get all agents in the user's library (both created and saved).\n\nArgs:\n user_id: ID of the authenticated user.\n search_term: Optional search term to filter agents by name/description.\n filter_by: List of filters to apply (favorites, created by user).\n sort_by: List of sorting criteria (created date, updated date).\n page: Page number to retrieve.\n page_size: Number of agents per page.\n\nReturns:\n A LibraryAgentResponse containing agents and pagination metadata.\n\nRaises:\n HTTPException: If a server/database error occurs.",
"description": "Get all agents in the user's library (both created and saved).",
"operationId": "getV2List library agents",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -3394,7 +3394,7 @@
],
"responses": {
"200": {
"description": "List of library agents",
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
@@ -3413,17 +3413,13 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"500": {
"description": "Server error",
"content": { "application/json": {} }
}
}
},
"post": {
"tags": ["v2", "library", "private"],
"summary": "Add Marketplace Agent",
"description": "Add an agent from the marketplace to the user's library.\n\nArgs:\n store_listing_version_id: ID of the store listing version to add.\n user_id: ID of the authenticated user.\n\nReturns:\n library_model.LibraryAgent: Agent added to the library\n\nRaises:\n HTTPException(404): If the listing version is not found.\n HTTPException(500): If a server/database error occurs.",
"description": "Add an agent from the marketplace to the user's library.",
"operationId": "postV2Add marketplace agent",
"security": [{ "HTTPBearerJWT": [] }],
"requestBody": {
@@ -3438,7 +3434,7 @@
},
"responses": {
"201": {
"description": "Agent added successfully",
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
@@ -3448,7 +3444,6 @@
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"404": { "description": "Store listing version not found" },
"422": {
"description": "Validation Error",
"content": {
@@ -3456,8 +3451,7 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"500": { "description": "Server error" }
}
}
}
},
@@ -3511,7 +3505,7 @@
"get": {
"tags": ["v2", "library", "private"],
"summary": "List Favorite Library Agents",
"description": "Get all favorite agents in the user's library.\n\nArgs:\n user_id: ID of the authenticated user.\n page: Page number to retrieve.\n page_size: Number of agents per page.\n\nReturns:\n A LibraryAgentResponse containing favorite agents and pagination metadata.\n\nRaises:\n HTTPException: If a server/database error occurs.",
"description": "Get all favorite agents in the user's library.",
"operationId": "getV2List favorite library agents",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -3563,10 +3557,6 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"500": {
"description": "Server error",
"content": { "application/json": {} }
}
}
}
@@ -3588,7 +3578,7 @@
],
"responses": {
"200": {
"description": "Library agent found",
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
@@ -3604,7 +3594,6 @@
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"404": { "description": "Agent not found" },
"422": {
"description": "Validation Error",
"content": {
@@ -3620,7 +3609,7 @@
"delete": {
"tags": ["v2", "library", "private"],
"summary": "Delete Library Agent",
"description": "Soft-delete the specified library agent.\n\nArgs:\n library_agent_id: ID of the library agent to delete.\n user_id: ID of the authenticated user.\n\nReturns:\n 204 No Content if successful.\n\nRaises:\n HTTPException(404): If the agent does not exist.\n HTTPException(500): If a server/database error occurs.",
"description": "Soft-delete the specified library agent.",
"operationId": "deleteV2Delete library agent",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -3636,11 +3625,9 @@
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"204": { "description": "Agent deleted successfully" },
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"404": { "description": "Agent not found" },
"422": {
"description": "Validation Error",
"content": {
@@ -3648,8 +3635,7 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"500": { "description": "Server error" }
}
}
},
"get": {
@@ -3690,7 +3676,7 @@
"patch": {
"tags": ["v2", "library", "private"],
"summary": "Update Library Agent",
"description": "Update the library agent with the given fields.\n\nArgs:\n library_agent_id: ID of the library agent to update.\n payload: Fields to update (auto_update_version, is_favorite, etc.).\n user_id: ID of the authenticated user.\n\nRaises:\n HTTPException(500): If a server/database error occurs.",
"description": "Update the library agent with the given fields.",
"operationId": "patchV2Update library agent",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -3713,7 +3699,7 @@
},
"responses": {
"200": {
"description": "Agent updated successfully",
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
@@ -3730,8 +3716,7 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"500": { "description": "Server error" }
}
}
}
},
@@ -4540,8 +4525,7 @@
"content": {
"application/json": {
"schema": {
"type": "boolean",
"title": "Response Getv1Is Onboarding Enabled"
"$ref": "#/components/schemas/OnboardingStatusResponse"
}
}
}
@@ -4594,6 +4578,7 @@
"AGENT_NEW_RUN",
"AGENT_INPUT",
"CONGRATS",
"VISIT_COPILOT",
"MARKETPLACE_VISIT",
"BUILDER_OPEN"
],
@@ -5927,6 +5912,40 @@
}
}
},
"/api/workspace/files/{file_id}/download": {
"get": {
"tags": ["workspace"],
"summary": "Download file by ID",
"description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
"operationId": "getWorkspaceDownload file by id",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "file_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "File Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/health": {
"get": {
"tags": ["health"],
@@ -8744,6 +8763,19 @@
"title": "OAuthApplicationPublicInfo",
"description": "Public information about an OAuth application (for consent screen)"
},
"OnboardingStatusResponse": {
"properties": {
"is_onboarding_enabled": {
"type": "boolean",
"title": "Is Onboarding Enabled"
},
"is_chat_enabled": { "type": "boolean", "title": "Is Chat Enabled" }
},
"type": "object",
"required": ["is_onboarding_enabled", "is_chat_enabled"],
"title": "OnboardingStatusResponse",
"description": "Response for onboarding status check."
},
"OnboardingStep": {
"type": "string",
"enum": [
@@ -8754,6 +8786,7 @@
"AGENT_NEW_RUN",
"AGENT_INPUT",
"CONGRATS",
"VISIT_COPILOT",
"GET_RESULTS",
"MARKETPLACE_VISIT",
"MARKETPLACE_ADD_AGENT",

View File

@@ -1,5 +1,6 @@
import {
ApiError,
getServerAuthToken,
makeAuthenticatedFileUpload,
makeAuthenticatedRequest,
} from "@/lib/autogpt-server-api/helpers";
@@ -15,6 +16,69 @@ function buildBackendUrl(path: string[], queryString: string): string {
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
}
/**
* Check if this is a workspace file download request that needs binary response handling.
*/
function isWorkspaceDownloadRequest(path: string[]): boolean {
// Match pattern: api/workspace/files/{id}/download (5 segments)
return (
path.length === 5 &&
path[0] === "api" &&
path[1] === "workspace" &&
path[2] === "files" &&
path[path.length - 1] === "download"
);
}
/**
* Handle workspace file download requests with proper binary response streaming.
*/
async function handleWorkspaceDownload(
req: NextRequest,
backendUrl: string,
): Promise<NextResponse> {
const token = await getServerAuthToken();
const headers: Record<string, string> = {};
if (token && token !== "no-token-found") {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(backendUrl, {
method: "GET",
headers,
redirect: "follow", // Follow redirects to signed URLs
});
if (!response.ok) {
return NextResponse.json(
{ error: `Failed to download file: ${response.statusText}` },
{ status: response.status },
);
}
// Get the content type from the backend response
const contentType =
response.headers.get("Content-Type") || "application/octet-stream";
const contentDisposition = response.headers.get("Content-Disposition");
// Stream the response body
const responseHeaders: Record<string, string> = {
"Content-Type": contentType,
};
if (contentDisposition) {
responseHeaders["Content-Disposition"] = contentDisposition;
}
// Return the binary content
const arrayBuffer = await response.arrayBuffer();
return new NextResponse(arrayBuffer, {
status: 200,
headers: responseHeaders,
});
}
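// Note: the handler above buffers the whole file via arrayBuffer() before
// responding, which is simple but holds the entire download in memory.
// Since fetch exposes the body as a ReadableStream, a streaming variant is
// possible (a sketch, not what this handler does):
//   return new NextResponse(response.body, { status: 200, headers: responseHeaders });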
async function handleJsonRequest(
req: NextRequest,
method: string,
@@ -180,6 +244,11 @@ async function handler(
};
try {
// Handle workspace file downloads separately (binary response)
if (method === "GET" && isWorkspaceDownloadRequest(path)) {
return await handleWorkspaceDownload(req, backendUrl);
}
if (method === "GET" || method === "DELETE") {
responseBody = await handleGetDeleteRequest(method, backendUrl, req);
} else if (contentType?.includes("application/json")) {

View File

@@ -0,0 +1,77 @@
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest, NextResponse } from "next/server";
const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions";
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit
function getExtensionFromMimeType(mimeType: string): string {
const subtype = mimeType.split("/")[1]?.split(";")[0];
return subtype || "webm";
}
export async function POST(request: NextRequest) {
const token = await getServerAuthToken();
if (!token || token === "no-token-found") {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}
const apiKey = process.env.OPENAI_API_KEY;
if (!apiKey) {
return NextResponse.json(
{ error: "OpenAI API key not configured" },
{ status: 500 },
);
}
try {
const formData = await request.formData();
const audioFile = formData.get("audio");
if (!audioFile || !(audioFile instanceof Blob)) {
return NextResponse.json(
{ error: "No audio file provided" },
{ status: 400 },
);
}
if (audioFile.size > MAX_FILE_SIZE) {
return NextResponse.json(
{ error: "File too large. Maximum size is 25MB." },
{ status: 413 },
);
}
const ext = getExtensionFromMimeType(audioFile.type);
const whisperFormData = new FormData();
whisperFormData.append("file", audioFile, `recording.${ext}`);
whisperFormData.append("model", "whisper-1");
const response = await fetch(WHISPER_API_URL, {
method: "POST",
headers: {
Authorization: `Bearer ${apiKey}`,
},
body: whisperFormData,
});
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
console.error("Whisper API error:", errorData);
return NextResponse.json(
{ error: errorData.error?.message || "Transcription failed" },
{ status: response.status },
);
}
const result = await response.json();
return NextResponse.json({ text: result.text });
} catch (error) {
console.error("Transcription error:", error);
return NextResponse.json(
{ error: "Failed to process audio" },
{ status: 500 },
);
}
}
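Worked examples of the extension mapping above:

getExtensionFromMimeType("audio/webm;codecs=opus"); // => "webm"
getExtensionFromMimeType("audio/mp4"); // => "mp4"
getExtensionFromMimeType("audio"); // => "webm" (fallback)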

View File

@@ -1,16 +1,17 @@
"use client";
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { useEffect, useRef } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
import { ChatLoader } from "./components/ChatLoader/ChatLoader";
import { useChat } from "./useChat";
export interface ChatProps {
className?: string;
urlSessionId?: string | null;
initialPrompt?: string;
onSessionNotFound?: () => void;
onStreamingChange?: (isStreaming: boolean) => void;
@@ -18,12 +19,13 @@ export interface ChatProps {
export function Chat({
className,
urlSessionId,
initialPrompt,
onSessionNotFound,
onStreamingChange,
}: ChatProps) {
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
const {
messages,
isLoading,
@@ -33,49 +35,59 @@ export function Chat({
sessionId,
createSession,
showLoader,
startPollingForOperation,
} = useChat({ urlSessionId });
useEffect(
function handleMissingSession() {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
if (!isSessionNotFound || isLoading || isCreating) return;
if (hasHandledNotFoundRef.current) return;
hasHandledNotFoundRef.current = true;
onSessionNotFound();
},
[onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating],
);
useEffect(() => {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
if (!isSessionNotFound || isLoading || isCreating) return;
if (hasHandledNotFoundRef.current) return;
hasHandledNotFoundRef.current = true;
onSessionNotFound();
}, [
onSessionNotFound,
urlSessionId,
isSessionNotFound,
isLoading,
isCreating,
]);
const shouldShowLoader =
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
return (
<div className={cn("flex h-full flex-col", className)}>
{/* Main Content */}
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
{/* Loading State */}
{showLoader && (isLoading || isCreating) && (
{shouldShowLoader && (
<div className="flex flex-1 items-center justify-center">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<div className="flex flex-col items-center gap-3">
<LoadingSpinner size="large" className="text-neutral-400" />
<Text variant="body" className="text-zinc-500">
Loading your chats...
{isSwitchingSession
? "Switching chat..."
: "Loading your chat..."}
</Text>
</div>
</div>
)}
{/* Error State */}
{error && !isLoading && (
{error && !isLoading && !isSwitchingSession && (
<ChatErrorState error={error} onRetry={createSession} />
)}
{/* Session Content */}
{sessionId && !isLoading && !error && (
{sessionId && !isLoading && !error && !isSwitchingSession && (
<ChatContainer
sessionId={sessionId}
initialMessages={messages}
initialPrompt={initialPrompt}
className="flex-1"
onStreamingChange={onStreamingChange}
onOperationStarted={startPollingForOperation}
/>
)}
</main>

View File

@@ -58,39 +58,17 @@ function notifyStreamComplete(
}
}
function cleanupCompletedStreams(completedStreams: Map<string, StreamResult>) {
function cleanupExpiredStreams(
completedStreams: Map<string, StreamResult>,
): Map<string, StreamResult> {
const now = Date.now();
for (const [sessionId, result] of completedStreams) {
const cleaned = new Map(completedStreams);
for (const [sessionId, result] of cleaned) {
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
completedStreams.delete(sessionId);
cleaned.delete(sessionId);
}
}
}
function moveToCompleted(
activeStreams: Map<string, ActiveStream>,
completedStreams: Map<string, StreamResult>,
streamCompleteCallbacks: Set<StreamCompleteCallback>,
sessionId: string,
) {
const stream = activeStreams.get(sessionId);
if (!stream) return;
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
completedStreams.set(sessionId, result);
activeStreams.delete(sessionId);
cleanupCompletedStreams(completedStreams);
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(streamCompleteCallbacks, sessionId);
}
return cleaned;
}
export const useChatStore = create<ChatStore>((set, get) => ({
@@ -106,17 +84,31 @@ export const useChatStore = create<ChatStore>((set, get) => ({
context,
onChunk,
) {
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
const existingStream = activeStreams.get(sessionId);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
moveToCompleted(
activeStreams,
completedStreams,
streamCompleteCallbacks,
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
);
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
const abortController = new AbortController();
@@ -132,36 +124,76 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunkCallbacks: initialCallbacks,
};
activeStreams.set(sessionId, stream);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
try {
await executeStream(stream, message, isUserMessage, context);
} finally {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
moveToCompleted(
activeStreams,
completedStreams,
streamCompleteCallbacks,
sessionId,
);
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(
currentState.streamCompleteCallbacks,
sessionId,
);
}
}
}
}
},
stopStream: function stopStream(sessionId) {
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
const stream = activeStreams.get(sessionId);
if (stream) {
stream.abortController.abort();
stream.status = "completed";
moveToCompleted(
activeStreams,
completedStreams,
streamCompleteCallbacks,
sessionId,
);
}
const state = get();
const stream = state.activeStreams.get(sessionId);
if (!stream) return;
stream.abortController.abort();
stream.status = "completed";
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
notifyStreamComplete(state.streamCompleteCallbacks, sessionId);
},
subscribeToStream: function subscribeToStream(
@@ -169,16 +201,18 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunk,
skipReplay = false,
) {
const { activeStreams } = get();
const state = get();
const stream = state.activeStreams.get(sessionId);
const stream = activeStreams.get(sessionId);
if (stream) {
if (!skipReplay) {
for (const chunk of stream.chunks) {
onChunk(chunk);
}
}
stream.onChunkCallbacks.add(onChunk);
return function unsubscribe() {
stream.onChunkCallbacks.delete(onChunk);
};
@@ -204,7 +238,12 @@ export const useChatStore = create<ChatStore>((set, get) => ({
},
clearCompletedStream: function clearCompletedStream(sessionId) {
get().completedStreams.delete(sessionId);
const state = get();
if (!state.completedStreams.has(sessionId)) return;
const newCompletedStreams = new Map(state.completedStreams);
newCompletedStreams.delete(sessionId);
set({ completedStreams: newCompletedStreams });
},
isStreaming: function isStreaming(sessionId) {
@@ -213,11 +252,21 @@ export const useChatStore = create<ChatStore>((set, get) => ({
},
registerActiveSession: function registerActiveSession(sessionId) {
get().activeSessions.add(sessionId);
const state = get();
if (state.activeSessions.has(sessionId)) return;
const newActiveSessions = new Set(state.activeSessions);
newActiveSessions.add(sessionId);
set({ activeSessions: newActiveSessions });
},
unregisterActiveSession: function unregisterActiveSession(sessionId) {
get().activeSessions.delete(sessionId);
const state = get();
if (!state.activeSessions.has(sessionId)) return;
const newActiveSessions = new Set(state.activeSessions);
newActiveSessions.delete(sessionId);
set({ activeSessions: newActiveSessions });
},
isSessionActive: function isSessionActive(sessionId) {
@@ -225,10 +274,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
},
onStreamComplete: function onStreamComplete(callback) {
const { streamCompleteCallbacks } = get();
streamCompleteCallbacks.add(callback);
const state = get();
const newCallbacks = new Set(state.streamCompleteCallbacks);
newCallbacks.add(callback);
set({ streamCompleteCallbacks: newCallbacks });
return function unsubscribe() {
streamCompleteCallbacks.delete(callback);
const currentState = get();
const cleanedCallbacks = new Set(currentState.streamCompleteCallbacks);
cleanedCallbacks.delete(callback);
set({ streamCompleteCallbacks: cleanedCallbacks });
};
},
}));
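The common thread in this store change: Zustand only re-renders subscribers when set() receives new references, so Map and Set fields are now copied before mutation. The idiom, distilled (field name taken from the store above; runs inside a store action where get/set are in scope):

// Copy, mutate the copy, publish the copy.
const next = new Map(get().completedStreams);
next.delete(sessionId);
set({ completedStreams: next });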

View File

@@ -16,6 +16,7 @@ export interface ChatContainerProps {
initialPrompt?: string;
className?: string;
onStreamingChange?: (isStreaming: boolean) => void;
onOperationStarted?: () => void;
}
export function ChatContainer({
@@ -24,6 +25,7 @@ export function ChatContainer({
initialPrompt,
className,
onStreamingChange,
onOperationStarted,
}: ChatContainerProps) {
const {
messages,
@@ -38,6 +40,7 @@ export function ChatContainer({
sessionId,
initialMessages,
initialPrompt,
onOperationStarted,
});
useEffect(() => {

View File

@@ -22,6 +22,7 @@ export interface HandlerDependencies {
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
sessionId: string;
onOperationStarted?: () => void;
}
export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -48,6 +49,15 @@ export function handleTextEnded(
const completedText = deps.streamingChunksRef.current.join("");
if (completedText.trim()) {
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
msg.role === "assistant" &&
msg.content === completedText,
);
if (exists) return prev;
const assistantMessage: ChatMessageData = {
type: "message",
role: "assistant",
@@ -154,6 +164,11 @@ export function handleToolResponse(
}
return;
}
// Trigger polling when operation_started is received
if (responseMessage.type === "operation_started") {
deps.onOperationStarted?.();
}
deps.setMessages((prev) => {
const toolCallIndex = prev.findIndex(
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
@@ -203,13 +218,24 @@ export function handleStreamEnd(
]);
}
if (completedContent.trim()) {
const assistantMessage: ChatMessageData = {
type: "message",
role: "assistant",
content: completedContent,
timestamp: new Date(),
};
deps.setMessages((prev) => [...prev, assistantMessage]);
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
msg.role === "assistant" &&
msg.content === completedContent,
);
if (exists) return prev;
const assistantMessage: ChatMessageData = {
type: "message",
role: "assistant",
content: completedContent,
timestamp: new Date(),
};
return [...prev, assistantMessage];
});
}
deps.setStreamingChunks([]);
deps.streamingChunksRef.current = [];

View File

@@ -304,6 +304,7 @@ export function parseToolResponse(
if (isAgentArray(agentsData)) {
return {
type: "agent_carousel",
toolId,
toolName: "agent_carousel",
agents: agentsData,
totalCount: parsedResult.total_count as number | undefined,
@@ -316,6 +317,7 @@ export function parseToolResponse(
if (responseType === "execution_started") {
return {
type: "execution_started",
toolId,
toolName: "execution_started",
executionId: (parsedResult.execution_id as string) || "",
agentName: (parsedResult.graph_name as string) || undefined,
@@ -341,6 +343,41 @@ export function parseToolResponse(
timestamp: timestamp || new Date(),
};
}
if (responseType === "operation_started") {
return {
type: "operation_started",
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
message:
(parsedResult.message as string) ||
"Operation started. You can close this tab.",
timestamp: timestamp || new Date(),
};
}
if (responseType === "operation_pending") {
return {
type: "operation_pending",
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
message:
(parsedResult.message as string) ||
"Operation in progress. Please wait...",
timestamp: timestamp || new Date(),
};
}
if (responseType === "operation_in_progress") {
return {
type: "operation_in_progress",
toolName: (parsedResult.tool_name as string) || toolName,
toolCallId: (parsedResult.tool_call_id as string) || toolId,
message:
(parsedResult.message as string) ||
"Operation already in progress. Please wait...",
timestamp: timestamp || new Date(),
};
}
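// Example payloads these branches expect (values hypothetical; field names
// taken from the reads above):
//   { type: "operation_started", tool_name: "run_agent",
//     operation_id: "op_42", message: "Operation started. You can close this tab." }
//   { type: "operation_in_progress", tool_name: "run_agent",
//     tool_call_id: "call_42", message: "Operation already in progress. Please wait..." }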
if (responseType === "need_login") {
return {
type: "login_needed",

View File

@@ -14,16 +14,40 @@ import {
processInitialMessages,
} from "./helpers";
// Helper to generate deduplication key for a message
function getMessageKey(msg: ChatMessageData): string {
if (msg.type === "message") {
// Don't include timestamp - dedupe by role + content only
// This handles the case where local and server timestamps differ
// Server messages are authoritative, so duplicates from local state are filtered
return `msg:${msg.role}:${msg.content}`;
} else if (msg.type === "tool_call") {
return `toolcall:${msg.toolId}`;
} else if (msg.type === "tool_response") {
return `toolresponse:${(msg as any).toolId}`;
} else if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
} else {
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
}
}
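// Example (hypothetical values): a locally streamed reply and the same reply
// replayed from server history both key to "msg:assistant:<content>", so the
// local copy is dropped during the merge below.
//   getMessageKey({ type: "message", role: "assistant",
//     content: "Done!", timestamp: new Date() }) // => "msg:assistant:Done!"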
interface Args {
sessionId: string | null;
initialMessages: SessionDetailResponse["messages"];
initialPrompt?: string;
onOperationStarted?: () => void;
}
export function useChatContainer({
sessionId,
initialMessages,
initialPrompt,
onOperationStarted,
}: Args) {
const [messages, setMessages] = useState<ChatMessageData[]>([]);
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
@@ -73,20 +97,102 @@ export function useChatContainer({
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
});
setIsStreamingInitiated(true);
const skipReplay = initialMessages.length > 0;
return subscribeToStream(sessionId, dispatcher, skipReplay);
},
[sessionId, stopStreaming, activeStreams, subscribeToStream],
[
sessionId,
stopStreaming,
activeStreams,
subscribeToStream,
onOperationStarted,
],
);
const allMessages = useMemo(
() => [...processInitialMessages(initialMessages), ...messages],
[initialMessages, messages],
// Collect toolIds from completed tool results in initialMessages
// Used to filter out operation messages when their results arrive
const completedToolIds = useMemo(() => {
const processedInitial = processInitialMessages(initialMessages);
const ids = new Set<string>();
for (const msg of processedInitial) {
if (
msg.type === "tool_response" ||
msg.type === "agent_carousel" ||
msg.type === "execution_started"
) {
const toolId = (msg as any).toolId;
if (toolId) {
ids.add(toolId);
}
}
}
return ids;
}, [initialMessages]);
// Clean up local operation messages when their completed results arrive from polling
// This effect runs when completedToolIds changes (i.e., when polling brings new results)
useEffect(
function cleanupCompletedOperations() {
if (completedToolIds.size === 0) return;
setMessages((prev) => {
const filtered = prev.filter((msg) => {
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false; // Remove - operation completed
}
}
return true;
});
// Only update state if something was actually filtered
return filtered.length === prev.length ? prev : filtered;
});
},
[completedToolIds],
);
// Combine initial messages from the backend with local streaming messages.
// Server messages maintain correct order; only append truly new local messages.
const allMessages = useMemo(() => {
const processedInitial = processInitialMessages(initialMessages);
// Build a set of keys from server messages for deduplication
const serverKeys = new Set<string>();
for (const msg of processedInitial) {
serverKeys.add(getMessageKey(msg));
}
// Filter local messages: remove duplicates and completed operation messages
const newLocalMessages = messages.filter((msg) => {
// Remove operation messages for completed tools
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false;
}
}
// Remove messages that already exist in server data
const key = getMessageKey(msg);
return !serverKeys.has(key);
});
// Server messages first (correct order), then new local messages
return [...processedInitial, ...newLocalMessages];
}, [initialMessages, messages, completedToolIds]);
async function sendMessage(
content: string,
isUserMessage: boolean = true,
@@ -118,6 +224,7 @@ export function useChatContainer({
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
});
try {

View File

@@ -1,7 +1,14 @@
import { Button } from "@/components/atoms/Button/Button";
import { cn } from "@/lib/utils";
import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react";
import {
ArrowUpIcon,
CircleNotchIcon,
MicrophoneIcon,
StopIcon,
} from "@phosphor-icons/react";
import { RecordingIndicator } from "./components/RecordingIndicator";
import { useChatInput } from "./useChatInput";
import { useVoiceRecording } from "./useVoiceRecording";
export interface Props {
onSend: (message: string) => void;
@@ -21,13 +28,36 @@ export function ChatInput({
className,
}: Props) {
const inputId = "chat-input";
const { value, handleKeyDown, handleSubmit, handleChange, hasMultipleLines } =
useChatInput({
onSend,
disabled: disabled || isStreaming,
maxRows: 4,
inputId,
});
const {
value,
setValue,
handleKeyDown: baseHandleKeyDown,
handleSubmit,
handleChange,
hasMultipleLines,
} = useChatInput({
onSend,
disabled: disabled || isStreaming,
maxRows: 4,
inputId,
});
const {
isRecording,
isTranscribing,
elapsedTime,
toggleRecording,
handleKeyDown,
showMicButton,
isInputDisabled,
audioStream,
} = useVoiceRecording({
setValue,
disabled: disabled || isStreaming,
isStreaming,
value,
baseHandleKeyDown,
});
return (
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
@@ -35,8 +65,11 @@ export function ChatInput({
<div
id={`${inputId}-wrapper`}
className={cn(
"relative overflow-hidden border border-neutral-200 bg-white shadow-sm",
"focus-within:border-zinc-400 focus-within:ring-1 focus-within:ring-zinc-400",
"relative overflow-hidden border bg-white shadow-sm",
"focus-within:ring-1",
isRecording
? "border-red-400 focus-within:border-red-400 focus-within:ring-red-400"
: "border-neutral-200 focus-within:border-zinc-400 focus-within:ring-zinc-400",
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
)}
>
@@ -46,48 +79,94 @@ export function ChatInput({
value={value}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={placeholder}
disabled={disabled || isStreaming}
placeholder={
isTranscribing
? "Transcribing..."
: isRecording
? ""
: placeholder
}
disabled={isInputDisabled}
rows={1}
className={cn(
"w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
"placeholder:text-zinc-400",
"focus:outline-none focus:ring-0",
"disabled:text-zinc-500",
hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4",
hasMultipleLines
? "pb-6 pl-4 pr-4 pt-2"
: showMicButton
? "pb-4 pl-14 pr-14 pt-4"
: "pb-4 pl-4 pr-14 pt-4",
)}
/>
{isRecording && !value && (
<div className="pointer-events-none absolute inset-0 flex items-center justify-center">
<RecordingIndicator
elapsedTime={elapsedTime}
audioStream={audioStream}
/>
</div>
)}
</div>
<span id="chat-input-hint" className="sr-only">
Press Enter to send, Shift+Enter for new line
Press Enter to send, Shift+Enter for new line, Space to record voice
</span>
{isStreaming ? (
<Button
type="button"
variant="icon"
size="icon"
aria-label="Stop generating"
onClick={onStop}
className="absolute bottom-[7px] right-2 border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
>
<StopIcon className="h-4 w-4" weight="bold" />
</Button>
) : (
<Button
type="submit"
variant="icon"
size="icon"
aria-label="Send message"
className={cn(
"absolute bottom-[7px] right-2 border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
(disabled || !value.trim()) && "opacity-20",
)}
disabled={disabled || !value.trim()}
>
<ArrowUpIcon className="h-4 w-4" weight="bold" />
</Button>
{showMicButton && (
<div className="absolute bottom-[7px] left-2 flex items-center gap-1">
<Button
type="button"
variant="icon"
size="icon"
aria-label={isRecording ? "Stop recording" : "Start recording"}
onClick={toggleRecording}
disabled={disabled || isTranscribing}
className={cn(
isRecording
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
: isTranscribing
? "border-zinc-300 bg-zinc-100 text-zinc-400"
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
)}
>
{isTranscribing ? (
<CircleNotchIcon className="h-4 w-4 animate-spin" />
) : (
<MicrophoneIcon className="h-4 w-4" weight="bold" />
)}
</Button>
</div>
)}
<div className="absolute bottom-[7px] right-2 flex items-center gap-1">
{isStreaming ? (
<Button
type="button"
variant="icon"
size="icon"
aria-label="Stop generating"
onClick={onStop}
className="border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
>
<StopIcon className="h-4 w-4" weight="bold" />
</Button>
) : (
<Button
type="submit"
variant="icon"
size="icon"
aria-label="Send message"
className={cn(
"border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
(disabled || !value.trim() || isRecording) && "opacity-20",
)}
disabled={disabled || !value.trim() || isRecording}
>
<ArrowUpIcon className="h-4 w-4" weight="bold" />
</Button>
)}
</div>
</div>
</form>
);

View File

@@ -0,0 +1,142 @@
"use client";
import { useEffect, useRef, useState } from "react";
interface Props {
stream: MediaStream | null;
barCount?: number;
barWidth?: number;
barGap?: number;
barColor?: string;
minBarHeight?: number;
maxBarHeight?: number;
}
export function AudioWaveform({
stream,
barCount = 24,
barWidth = 3,
barGap = 2,
barColor = "#ef4444", // red-500
minBarHeight = 4,
maxBarHeight = 32,
}: Props) {
const [bars, setBars] = useState<number[]>(() =>
Array(barCount).fill(minBarHeight),
);
const analyserRef = useRef<AnalyserNode | null>(null);
const audioContextRef = useRef<AudioContext | null>(null);
const sourceRef = useRef<MediaStreamAudioSourceNode | null>(null);
const animationRef = useRef<number | null>(null);
useEffect(() => {
if (!stream) {
setBars(Array(barCount).fill(minBarHeight));
return;
}
// Create audio context and analyser
const audioContext = new AudioContext();
const analyser = audioContext.createAnalyser();
analyser.fftSize = 512;
analyser.smoothingTimeConstant = 0.8;
// Connect the stream to the analyser
const source = audioContext.createMediaStreamSource(stream);
source.connect(analyser);
audioContextRef.current = audioContext;
analyserRef.current = analyser;
sourceRef.current = source;
const timeData = new Uint8Array(analyser.frequencyBinCount);
const updateBars = () => {
if (!analyserRef.current) return;
analyserRef.current.getByteTimeDomainData(timeData);
// Distribute time-domain data across bars
// This shows waveform amplitude, making all bars respond to audio
const newBars: number[] = [];
const samplesPerBar = timeData.length / barCount;
for (let i = 0; i < barCount; i++) {
// Sample waveform data for this bar
let maxAmplitude = 0;
const startIdx = Math.floor(i * samplesPerBar);
const endIdx = Math.floor((i + 1) * samplesPerBar);
for (let j = startIdx; j < endIdx && j < timeData.length; j++) {
// Convert to amplitude (distance from center 128)
const amplitude = Math.abs(timeData[j] - 128);
maxAmplitude = Math.max(maxAmplitude, amplitude);
}
// Map amplitude (0-128) to bar height
const normalized = (maxAmplitude / 128) * 255;
const height =
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
newBars.push(height);
}
setBars(newBars);
animationRef.current = requestAnimationFrame(updateBars);
};
updateBars();
return () => {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
}
if (sourceRef.current) {
sourceRef.current.disconnect();
}
if (audioContextRef.current) {
audioContextRef.current.close();
}
analyserRef.current = null;
audioContextRef.current = null;
sourceRef.current = null;
};
}, [stream, barCount, minBarHeight, maxBarHeight]);
const totalWidth = barCount * barWidth + (barCount - 1) * barGap;
return (
<div
className="flex items-center justify-center"
style={{
width: totalWidth,
height: maxBarHeight,
gap: barGap,
}}
>
{bars.map((height, i) => {
const barHeight = Math.max(minBarHeight, height);
return (
<div
key={i}
className="relative"
style={{
width: barWidth,
height: maxBarHeight,
}}
>
<div
className="absolute left-0 rounded-full transition-[height] duration-75"
style={{
width: barWidth,
height: barHeight,
top: "50%",
transform: "translateY(-50%)",
backgroundColor: barColor,
}}
/>
</div>
);
})}
</div>
);
}

View File

@@ -0,0 +1,26 @@
import { formatElapsedTime } from "../helpers";
import { AudioWaveform } from "./AudioWaveform";
type Props = {
elapsedTime: number;
audioStream: MediaStream | null;
};
export function RecordingIndicator({ elapsedTime, audioStream }: Props) {
return (
<div className="flex items-center gap-3">
<AudioWaveform
stream={audioStream}
barCount={20}
barWidth={3}
barGap={2}
barColor="#ef4444"
minBarHeight={4}
maxBarHeight={24}
/>
<span className="min-w-[3ch] text-sm font-medium text-red-500">
{formatElapsedTime(elapsedTime)}
</span>
</div>
);
}

View File

@@ -0,0 +1,6 @@
export function formatElapsedTime(ms: number): string {
const seconds = Math.floor(ms / 1000);
const minutes = Math.floor(seconds / 60);
const remainingSeconds = seconds % 60;
return `${minutes}:${remainingSeconds.toString().padStart(2, "0")}`;
}
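Worked examples:

formatElapsedTime(5_000); // => "0:05"
formatElapsedTime(125_000); // => "2:05" (seconds are zero-padded)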

View File

@@ -6,7 +6,7 @@ import {
useState,
} from "react";
interface UseChatInputArgs {
interface Args {
onSend: (message: string) => void;
disabled?: boolean;
maxRows?: number;
@@ -18,7 +18,7 @@ export function useChatInput({
disabled = false,
maxRows = 5,
inputId = "chat-input",
}: UseChatInputArgs) {
}: Args) {
const [value, setValue] = useState("");
const [hasMultipleLines, setHasMultipleLines] = useState(false);

View File

@@ -0,0 +1,240 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import React, {
KeyboardEvent,
useCallback,
useEffect,
useRef,
useState,
} from "react";
const MAX_RECORDING_DURATION = 2 * 60 * 1000; // 2 minutes in ms
interface Args {
setValue: React.Dispatch<React.SetStateAction<string>>;
disabled?: boolean;
isStreaming?: boolean;
value: string;
baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
}
export function useVoiceRecording({
setValue,
disabled = false,
isStreaming = false,
value,
baseHandleKeyDown,
}: Args) {
const [isRecording, setIsRecording] = useState(false);
const [isTranscribing, setIsTranscribing] = useState(false);
const [error, setError] = useState<string | null>(null);
const [elapsedTime, setElapsedTime] = useState(0);
const mediaRecorderRef = useRef<MediaRecorder | null>(null);
const chunksRef = useRef<Blob[]>([]);
const timerRef = useRef<NodeJS.Timeout | null>(null);
const startTimeRef = useRef<number>(0);
const streamRef = useRef<MediaStream | null>(null);
const isRecordingRef = useRef(false);
const isSupported =
typeof window !== "undefined" &&
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
const clearTimer = useCallback(() => {
if (timerRef.current) {
clearInterval(timerRef.current);
timerRef.current = null;
}
}, []);
const cleanup = useCallback(() => {
clearTimer();
if (streamRef.current) {
streamRef.current.getTracks().forEach((track) => track.stop());
streamRef.current = null;
}
mediaRecorderRef.current = null;
chunksRef.current = [];
setElapsedTime(0);
}, [clearTimer]);
const handleTranscription = useCallback(
(text: string) => {
setValue((prev) => {
const trimmedPrev = prev.trim();
if (trimmedPrev) {
return `${trimmedPrev} ${text}`;
}
return text;
});
},
[setValue],
);
const transcribeAudio = useCallback(
async (audioBlob: Blob) => {
setIsTranscribing(true);
setError(null);
try {
const formData = new FormData();
formData.append("audio", audioBlob);
const response = await fetch("/api/transcribe", {
method: "POST",
body: formData,
});
if (!response.ok) {
const data = await response.json().catch(() => ({}));
throw new Error(data.error || "Transcription failed");
}
const data = await response.json();
if (data.text) {
handleTranscription(data.text);
}
} catch (err) {
const message =
err instanceof Error ? err.message : "Transcription failed";
setError(message);
console.error("Transcription error:", err);
} finally {
setIsTranscribing(false);
}
},
[handleTranscription],
);
const stopRecording = useCallback(() => {
if (mediaRecorderRef.current && isRecordingRef.current) {
mediaRecorderRef.current.stop();
isRecordingRef.current = false;
setIsRecording(false);
clearTimer();
}
}, [clearTimer]);
const startRecording = useCallback(async () => {
if (disabled || isRecordingRef.current || isTranscribing) return;
setError(null);
chunksRef.current = [];
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
streamRef.current = stream;
const mediaRecorder = new MediaRecorder(stream, {
mimeType: MediaRecorder.isTypeSupported("audio/webm")
? "audio/webm"
: "audio/mp4",
});
mediaRecorderRef.current = mediaRecorder;
mediaRecorder.ondataavailable = (event) => {
if (event.data.size > 0) {
chunksRef.current.push(event.data);
}
};
mediaRecorder.onstop = async () => {
const audioBlob = new Blob(chunksRef.current, {
type: mediaRecorder.mimeType,
});
// Stop the tracks so the browser releases the microphone
if (streamRef.current) {
streamRef.current.getTracks().forEach((track) => track.stop());
streamRef.current = null;
}
if (audioBlob.size > 0) {
await transcribeAudio(audioBlob);
}
};
mediaRecorder.start(1000); // Collect data every second
isRecordingRef.current = true;
setIsRecording(true);
startTimeRef.current = Date.now();
// Start elapsed time timer
timerRef.current = setInterval(() => {
const elapsed = Date.now() - startTimeRef.current;
setElapsedTime(elapsed);
// Auto-stop at max duration
if (elapsed >= MAX_RECORDING_DURATION) {
stopRecording();
}
}, 100);
} catch (err) {
console.error("Failed to start recording:", err);
if (err instanceof DOMException && err.name === "NotAllowedError") {
setError("Microphone permission denied");
} else {
setError("Failed to access microphone");
}
cleanup();
}
}, [disabled, isTranscribing, stopRecording, transcribeAudio, cleanup]);
const toggleRecording = useCallback(() => {
if (isRecording) {
stopRecording();
} else {
startRecording();
}
}, [isRecording, startRecording, stopRecording]);
const { toast } = useToast();
useEffect(() => {
if (error) {
toast({
title: "Voice recording failed",
description: error,
variant: "destructive",
});
}
}, [error, toast]);
const handleKeyDown = useCallback(
(event: KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === " " && !value.trim() && !isTranscribing) {
event.preventDefault();
toggleRecording();
return;
}
baseHandleKeyDown(event);
},
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
);
const showMicButton = isSupported && !isStreaming;
const isInputDisabled = disabled || isStreaming || isTranscribing;
// Cleanup on unmount
useEffect(() => {
return () => {
cleanup();
};
}, [cleanup]);
return {
isRecording,
isTranscribing,
error,
elapsedTime,
startRecording,
stopRecording,
toggleRecording,
isSupported,
handleKeyDown,
showMicButton,
isInputDisabled,
audioStream: streamRef.current,
};
}
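
How the two hooks compose is implied rather than shown in this diff. A minimal sketch, assuming useChatInput returns { value, setValue, handleKeyDown } (its full return shape is not included here) and that sendMessage, disabled, and isStreaming come from the surrounding component:

```ts
// Sketch only; useChatInput's return shape and the surrounding variables are assumptions.
const { value, setValue, handleKeyDown: baseHandleKeyDown } = useChatInput({
  onSend: sendMessage,
  disabled,
});

const {
  isRecording,
  isTranscribing,
  elapsedTime,
  toggleRecording,
  handleKeyDown, // Space on an empty input toggles recording; otherwise defers to baseHandleKeyDown
  showMicButton,
  isInputDisabled,
  audioStream,
} = useVoiceRecording({ setValue, disabled, isStreaming, value, baseHandleKeyDown });

// While recording, the indicator receives the live stream and the running timer:
// <RecordingIndicator elapsedTime={elapsedTime} audioStream={audioStream} />
```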


@@ -16,6 +16,7 @@ import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
import { PendingOperationWidget } from "../PendingOperationWidget/PendingOperationWidget";
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
@@ -71,6 +72,9 @@ export function ChatMessage({
isLoginNeeded,
isCredentialsNeeded,
isClarificationNeeded,
isOperationStarted,
isOperationPending,
isOperationInProgress,
} = useChatMessage(message);
const displayContent = getDisplayContent(message, isUser);
@@ -126,10 +130,6 @@ export function ChatMessage({
[displayContent, message],
);
-function isLongResponse(content: string): boolean {
-  return content.split("\n").length > 5;
-}
const handleTryAgain = useCallback(() => {
if (message.type !== "message" || !onSendMessage) return;
onSendMessage(message.content, message.role === "user");
@@ -294,6 +294,42 @@ export function ChatMessage({
);
}
// Render operation_started messages (long-running background operations)
if (isOperationStarted && message.type === "operation_started") {
return (
<PendingOperationWidget
status="started"
message={message.message}
toolName={message.toolName}
className={className}
/>
);
}
// Render operation_pending messages (operations in progress when refreshing)
if (isOperationPending && message.type === "operation_pending") {
return (
<PendingOperationWidget
status="pending"
message={message.message}
toolName={message.toolName}
className={className}
/>
);
}
// Render operation_in_progress messages (duplicate request while operation running)
if (isOperationInProgress && message.type === "operation_in_progress") {
return (
<PendingOperationWidget
status="in_progress"
message={message.message}
toolName={message.toolName}
className={className}
/>
);
}
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
if (isToolResponse && message.type === "tool_response") {
return (
@@ -358,7 +394,7 @@ export function ChatMessage({
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
</Button>
)}
-{!isUser && isFinalMessage && isLongResponse(displayContent) && (
+{!isUser && isFinalMessage && !isStreaming && (
<Button
variant="ghost"
size="icon"
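
From the three call sites above, PendingOperationWidget evidently takes a status discriminator alongside the message text. Its own file is not part of this diff, so the shape below is inferred from usage rather than quoted:

```ts
// Inferred from the call sites above; the widget's actual file is not in this diff.
type PendingOperationWidgetProps = {
  status: "started" | "pending" | "in_progress";
  message: string;
  toolName: string;
  className?: string;
};
```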


@@ -61,6 +61,7 @@ export type ChatMessageData =
}
| {
type: "agent_carousel";
toolId: string;
toolName: string;
agents: Array<{
id: string;
@@ -74,6 +75,7 @@ export type ChatMessageData =
}
| {
type: "execution_started";
toolId: string;
toolName: string;
executionId: string;
agentName?: string;
@@ -103,6 +105,29 @@ export type ChatMessageData =
message: string;
sessionId: string;
timestamp?: string | Date;
}
| {
type: "operation_started";
toolName: string;
toolId: string;
operationId: string;
message: string;
timestamp?: string | Date;
}
| {
type: "operation_pending";
toolName: string;
toolId: string;
operationId: string;
message: string;
timestamp?: string | Date;
}
| {
type: "operation_in_progress";
toolName: string;
toolCallId: string;
message: string;
timestamp?: string | Date;
};
export function useChatMessage(message: ChatMessageData) {
@@ -124,5 +149,8 @@ export function useChatMessage(message: ChatMessageData) {
isExecutionStarted: message.type === "execution_started",
isInputsNeeded: message.type === "inputs_needed",
isClarificationNeeded: message.type === "clarification_needed",
isOperationStarted: message.type === "operation_started",
isOperationPending: message.type === "operation_pending",
isOperationInProgress: message.type === "operation_in_progress",
};
}
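
For reference, a value satisfying one of the new union members above; the tool name, IDs, and text are invented for illustration:

```ts
// Illustrative payload for the new "operation_started" variant (names and IDs invented).
const msg: ChatMessageData = {
  type: "operation_started",
  toolName: "create_agent",
  toolId: "tool_123",
  operationId: "op_456",
  message: "Working on it. This can take a few minutes.",
};

const { isOperationStarted } = useChatMessage(msg); // true
```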

Some files were not shown because too many files have changed in this diff.