Compare commits

...

17 Commits

Author SHA1 Message Date
Swifty
af7e8d19fe feat(platform): add beta invite provisioning 2026-03-09 16:38:23 +01:00
Otto
eadc68f2a5 feat(frontend/copilot): move microphone button to right side of input box (#12320)
Requested by @olivia-1421

Moves the microphone/recording button from the left-side tools group to
the right side, next to the submit button. The left side is now reserved
for the attachment/upload (plus) button only.

**Before:** `[ 📎 🎤 ] .................. [ ➤ ]`
**After:**  `[ 📎 ] .................. [ 🎤 ➤ ]`

---
Co-authored-by: Olivia <olivia-1421@users.noreply.github.com>

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Lluis Agusti <hi@llu.lu>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:37:02 +08:00
Reinier van der Leer
eca7b5e793 Merge commit from fork 2026-03-08 10:24:44 +01:00
Otto
c304a4937a fix(backend): Handle manual run attempts for triggered agents (#12298)
When a webhook-triggered agent is executed directly (e.g. via Copilot)
without actual webhook data, `GraphExecution.from_db()` crashes with
`KeyError: 'payload'` because it does a hard key access on
`exec.input_data["payload"]` for webhook blocks.

This caused 232 Sentry events (AUTOGPT-SERVER-821) and multiple
INCOMPLETE graph executions due to retries.

**Changes:**

1. **Defensive fix in `from_db()`** — use `.get("payload")` instead of
`["payload"]` to handle missing keys gracefully (matches existing
pattern for input blocks using `.get("value")`)

2. **Upfront refusal in `_construct_starting_node_execution_input()`** —
refuse execution of webhook/webhook_manual blocks when no payload is
provided. The check is placed after `nodes_input_masks` application, so
legitimate webhook triggers (which inject payload via
`nodes_input_masks`) pass through fine.

Resolves [SENTRY-1113: Copilot is able to manually initiate runs for
triggered agents (which
fails)](https://linear.app/autogpt/issue/SENTRY-1113/copilot-is-able-to-manually-initiate-runs-for-triggered-agents-which)

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-06 20:47:51 +00:00
Zamil Majdy
8cfabcf4fd refactor(backend/copilot): centralize prompt building in prompting.py (#12324)
## Summary

Centralizes all prompt building logic into a new
`backend/copilot/prompting.py` module with clear SDK vs baseline and
local vs E2B distinctions.

### Key Changes

**New `prompting.py` module:**
- `get_sdk_supplement(use_e2b, cwd)` - For SDK mode (NO tool docs -
Claude gets schemas automatically)
- `get_baseline_supplement(use_e2b, cwd)` - For baseline mode (WITH
auto-generated tool docs from TOOL_REGISTRY)
- Handles local/E2B storage differences

**SDK mode (`sdk/service.py`):**
- Removed 165+ lines of duplicate constants
- Now imports and uses `get_sdk_supplement()`
- Cleaner, more maintainable

**Baseline mode (`baseline/service.py`):**
- Now appends `get_baseline_supplement()` to system prompt
- Baseline mode finally gets tool documentation!

**Enhanced tool descriptions:**
- `create_agent`: Added feedback loop workflow (suggested_goal,
clarifying_questions)
- `run_mcp_tool`: Added known server URLs, 2-step workflow, auth
handling

**Tests:**
- Updated to verify SDK excludes tool docs, baseline includes them
- All existing tests pass

### Architecture Benefits

 Single source of truth for prompt supplements
 Clear SDK vs baseline distinction (SDK doesn't need tool docs)
 Clear local vs E2B distinction (storage systems)
 Easy to maintain and update
 Eliminates code duplication

## Test plan

- [x] Unit tests pass (TestPromptSupplement class)
- [x] SDK mode excludes tool documentation
- [x] Baseline mode includes tool documentation
- [x] E2B vs local mode differences handled correctly
2026-03-06 18:56:20 +00:00
Zamil Majdy
7bf407b66c Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-07 02:01:41 +07:00
Zamil Majdy
7ead4c040f hotfix(backend/copilot): capture tool results in transcript (#12323)
## Summary
- Fixes tool results not being captured in the CoPilot transcript during
SDK-based streaming
- Adds `transcript_builder.add_user_message()` call with `tool_result`
content block when a `StreamToolOutputAvailable` event is received
- Ensures transcript accurately reflects the full conversation including
tool outputs, which is critical for Langfuse tracing and debugging

## Context
After the transcript refactor in #12318, tool call results from the SDK
streaming loop were not being recorded in the transcript. This meant
Langfuse traces were missing tool outputs, making it hard to debug agent
behavior.

## Test plan
- [ ] Verify CoPilot conversation with tool calls captures tool results
in Langfuse traces
- [ ] Verify transcript includes tool_result content blocks after tool
execution
2026-03-06 18:58:48 +00:00
Abhimanyu Yadav
0f813f1bf9 feat(copilot): Add folder management tools to CoPilot (#12290)
Adds folder management capabilities to the CoPilot, allowing users to
organize agents into folders directly from the chat interface.

<img width="823" height="356" alt="Screenshot 2026-03-05 at 5 26 30 PM"
src="https://github.com/user-attachments/assets/4c55f926-1e71-488f-9eb6-fca87c4ab01b"
/>
<img width="797" height="150" alt="Screenshot 2026-03-05 at 5 28 40 PM"
src="https://github.com/user-attachments/assets/5c9c6f8b-57ac-4122-b17d-b9f091bb7c4e"
/>
<img width="763" height="196" alt="Screenshot 2026-03-05 at 5 28 36 PM"
src="https://github.com/user-attachments/assets/d1b22b5d-921d-44ac-90e8-a5820bb3146d"
/>
<img width="756" height="199" alt="Screenshot 2026-03-05 at 5 30 17 PM"
src="https://github.com/user-attachments/assets/40a59748-f42e-4521-bae0-cc786918a9b5"
/>

### Changes

**Backend -- 6 new CoPilot tools** (`manage_folders.py`):
- `create_folder` -- Create folders with optional parent, icon, and
color
- `list_folders` -- List folder tree or children of a specific folder,
with optional `include_agents` to show agents inside each folder
- `update_folder` -- Rename or change icon/color
- `move_folder` -- Reparent a folder or move to root
- `delete_folder` -- Soft-delete (agents moved to root, not deleted)
- `move_agents_to_folder` -- Bulk-move agents into a folder or back to
root

**Backend -- DatabaseManager RPC registration**:
- Registered all 7 folder DB functions (`create_folder`, `list_folders`,
`get_folder_tree`, `update_folder`, `move_folder`, `delete_folder`,
`bulk_move_agents_to_folder`) in `DatabaseManager` and
`DatabaseManagerAsyncClient` so they work via RPC in the CoPilotExecutor
process
- `manage_folders.py` uses `db_accessors.library_db()` pattern
(consistent with all other copilot tools) instead of direct Prisma
imports

**Backend -- folder_id threading**:
- `create_agent` and `customize_agent` tools accept optional `folder_id`
to save agents directly into a folder
- `save_agent_to_library` -> `create_graph_in_library` ->
`create_library_agent` pipeline passes `folder_id` through
- `create_library_agent` refactored from `asyncio.gather` to sequential
loop to support conditional `folderId` assignment on the main graph only
(not sub-graphs)

**Backend -- system prompt and models**:
- Added folder tool descriptions and usage guidance to Otto's system
prompt
- Added `FolderAgentSummary` model for lightweight agent info in folder
listings
- Added 6 `ResponseType` enum values and corresponding Pydantic response
models (`FolderInfo`, `FolderTreeInfo`, `FolderCreatedResponse`, etc.)

**Frontend -- FolderTool UI component**:
- `FolderTool.tsx` -- Renders folder operations in chat using the
`file-tree` molecule component for tree view, with `FileIcon` for agents
and `FolderIcon` for folders (both `text-neutral-600`)
- `helpers.ts` -- Type guards, output parsing, animation text helpers,
and `FolderAgentSummary` type
- `MessagePartRenderer.tsx` -- Routes 6 folder tool types to
`FolderTool` component
- Flat folder list view shows agents inside `FolderCard` when
`include_agents` is set

**Frontend -- file-tree molecule**:
- Fixed 3 pre-existing lint errors in `file-tree.tsx` (unused `ref`,
`handleSelect`, `className` params)
- Updated tree indicator line color from `bg-neutral-100` to
`bg-neutral-400` for visibility
- Added `file-tree.stories.tsx` with 5 stories: Default, AllExpanded,
FoldersOnly, WithInitialSelection, NoIndicator
- Added `ui/scroll-area.tsx` (dependency of file-tree, was missing from
non-legacy ui folder)

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a folder via copilot chat ("create a folder called
Marketing")
  - [x] List folders ("show me my folders")
- [x] List folders with agents ("show me my folders and the agents in
them")
- [x] Update folder name/icon/color ("rename Marketing folder to Sales")
- [x] Move folder to a different parent ("move Sales into the Projects
folder")
  - [x] Delete a folder and verify agents move to root
- [x] Move agents into a folder ("put my newsletter agent in the
Marketing folder")
- [x] Create agent with folder_id ("create a scraper agent and save it
in my Tools folder")
- [x] Verify FolderTool UI renders loading, success, error, and empty
states correctly
- [x] Verify folder tree renders nested folders with file-tree component
- [x] Verify agents appear as FileIcon nodes in tree view when
include_agents is true
  - [x] Verify file-tree storybook stories render correctly
2026-03-06 14:59:03 +00:00
Reinier van der Leer
aa08063939 refactor(backend/db): Improve & clean up Marketplace DB layer & API (#12284)
These changes were part of #12206, but here they are separately for
easier review.
This is all primarily to make the v2 API (#11678) work possible/easier.

### Changes 🏗️

- Fix relations between `Profile`, `StoreListing`, and `AgentGraph`
- Redefine `StoreSubmission` view with more efficient joins (100x
speed-up on dev DB) and more consistent field names
- Clean up query functions in `store/db.py`
- Clean up models in `store/model.py`
- Add missing fields to `StoreAgent` and `StoreSubmission` views
- Rename ambiguous `agent_id` -> `graph_id`
- Clean up API route definitions & docs in `store/routes.py`
  - Make routes more consistent
- Avoid collision edge-case between `/agents/{username}/{agent_name}`
and `/agents/{store_listing_version_id}/*`
- Replace all usages of legacy `BackendAPI` for store endpoints with
generated client
- Remove scope requirements on public store endpoints in v1 external API

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test all Marketplace views (including admin views)
    - [x] Download an agent from the marketplace
  - [x] Submit an agent to the Marketplace
  - [x] Approve/reject Marketplace submission
2026-03-06 14:38:12 +00:00
Zamil Majdy
bde6a4c0df Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev
# Conflicts:
#	autogpt_platform/backend/backend/copilot/sdk/service.py
2026-03-06 21:07:37 +07:00
Zamil Majdy
d56452898a hotfix(backend/copilot): refactor transcript to SDK-based atomic full-context model (#12318)
## Summary

Major refactor to eliminate CLI transcript race conditions and simplify
the codebase by building transcripts directly from SDK messages instead
of reading CLI files.

## Problem

The previous approach had race conditions:
- SDK reads CLI transcript file during stop hook
- CLI may not have finished writing → incomplete transcript
- Complex merge logic to detect and fix incomplete writes
- ~200 lines of synthetic entry detection and merge code

## Solution

**Atomic Full-Context Transcript Model:**
- Build transcript from SDK messages during streaming
(`TranscriptBuilder`)
- Each upload REPLACES the previous transcript entirely (atomic)
- No CLI file reading → no race conditions
- Eliminates all merge complexity

## Key Changes

### Core Refactor
- **NEW**: `transcript_builder.py` - Build JSONL from SDK messages
during streaming
- **SIMPLIFIED**: `transcript.py` - Removed merge logic, simplified
upload/download
- **SIMPLIFIED**: `service.py` - Use TranscriptBuilder, removed stop
hook callback
- **CLEANED**: `security_hooks.py` - Removed `on_stop` parameter

### Performance & Code Quality
- **orjson migration**: Use `backend.util.json` (2-3x faster than
stdlib)
- Added `fallback` parameter to `json.loads()` for cleaner error
handling
- Moved SDK imports to top-level per code style guidelines

### Bug Fixes
- Fixed garbage collection bug in background task handling
- Fixed double upload bug in timeout handling  
- Downgraded PII-risk logging from WARNING to DEBUG
- Added 30s timeout to prevent session lock hang

## Code Removed (~200 lines)

- `merge_with_previous_transcript()` - No longer needed
- `read_transcript_file()` - No longer needed
- `CapturedTranscript` dataclass - No longer needed
- `_on_stop()` callback - No longer needed
- Synthetic entry detection logic - No longer needed
- Manual append/merge logic in finally block - No longer needed

## Testing

-  All transcript tests passing (24/24)
-  Verified with real session logs showing proper transcript growth
-  Verified with Langfuse traces showing proper turn tracking (1-8)

## Transcript Growth Pattern

From session logs:
- **Turn 1**: 2 entries (initial)
- **Turn 2**: 5 entries (+3), 2257B uploaded
- **Turn N**: ~2N entries (linear growth)

Each upload is the **complete atomic state** - always REPLACES, never
incremental.

## Files Changed

```
backend/copilot/sdk/transcript_builder.py (NEW)   | +140 lines
backend/copilot/sdk/transcript.py                  | -198, +125 lines  
backend/copilot/sdk/service.py                     | -214, +160 lines
backend/copilot/sdk/security_hooks.py              | -33, +10 lines
backend/copilot/sdk/transcript_test.py             | -85, +36 lines
backend/util/json.py                               | +45 lines
```

**Net result**: -200 lines, more reliable, faster JSON operations.

## Migration Notes

This is a **breaking change** for any code that:
- Directly calls `merge_with_previous_transcript()` or
`read_transcript_file()`
- Relies on incremental transcript uploads
- Expects stop hook callbacks

All internal usage has been updated.

---

@ntindle - Tagging for autogpt-reviewer
2026-03-06 21:03:49 +07:00
Ubbe
7507240177 feat(copilot): collapse repeated tool calls and fix stream stuck on completion (#12282)
## Summary
- **Frontend:** Group consecutive completed generic tool parts into
collapsible summary rows with a "Reasoning" collapse for finalized
messages. Merge consecutive assistant messages on hydration to avoid
split bubbles. Extract GenericTool helpers. Add `reconnectExhausted`
state and a brief delay before refetching session to reduce stale
`active_stream` reconnect cycles.
- **Backend:** Make transcript upload fire-and-forget instead of
blocking the generator exit. The 30s upload timeout in
`_try_upload_transcript` was delaying `mark_session_completed()`,
keeping the SSE stream alive with only heartbeats after the LLM had
finished — causing the UI to stay stuck in "streaming" state.

## Test plan
- [ ] Send a message in Copilot that triggers multiple tool calls —
verify they collapse into a grouped summary row once completed
- [ ] Verify the final text response appears below the collapsed
reasoning section
- [ ] Confirm the stream properly closes after the agent finishes (no
stuck "Stop" button)
- [ ] Refresh mid-stream and verify reconnection works correctly
- [ ] Click Stop during streaming — verify the UI becomes responsive
immediately

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 21:21:59 +08:00
Abhimanyu Yadav
d7c3f5b8fc fix(frontend): bypass Next.js proxy for file uploads to fix 413 error (#12315)
## Summary
- File uploads routed through the Next.js API proxy (`/api/proxy/...`)
fail with HTTP 413 for files >4.5MB due to Vercel's serverless function
body size limit
- Created shared `uploadFileDirect` utility (`src/lib/direct-upload.ts`)
that uploads files directly from the browser to the Python backend,
bypassing the proxy entirely
- Updated `useWorkspaceUpload` to use direct upload instead of the
generated hook (which went through the proxy)
- Deduplicated the copilot page's inline upload logic to use the same
shared utility

## Changes 🏗️
- **New**: `src/lib/direct-upload.ts` — shared utility for
direct-to-backend file uploads (up to 256MB)
- **Updated**: `useWorkspaceUpload.ts` — replaced proxy-based generated
hook with `uploadFileDirect`
- **Updated**: `useCopilotPage.ts` — replaced inline upload logic with
shared `uploadFileDirect`, removed unused imports

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Upload a file >5MB via workspace file input (e.g. in agent
builder) — should succeed without 413
  - [x] Upload a file >5MB via copilot chat — should succeed without 413
  - [x] Upload a small file (<1MB) via both paths — should still work
  - [x] Verify file delete still works from workspace file input
2026-03-06 12:20:18 +00:00
Otto
3e108a813a fix(backend): Use db_manager for workspace in add_graph_execution (#12312)
When `add_graph_execution` is called from a context where the global
Prisma client isn't connected (e.g. CoPilot tools, external API), the
call to `get_or_create_workspace(user_id)` crashes with
`ClientNotConnectedError` because it directly accesses
`UserWorkspace.prisma()`.

The fix adds `workspace_db` to the existing `if prisma.is_connected()`
fallback pattern, consistent with how all other DB calls in the function
already work.

**Sentry:** AUTOGPT-SERVER-83T (and ~15 related issues going back to Jan
2026)

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>

Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-06 08:48:15 +01:00
Krzysztof Czerwinski
08c49a78f8 feat(copilot): UX improvements (#12258)
CoPilot conversation UX improvements (SECRT-2055):

1. **Rename conversations** — Inline rename via the session dropdown
menu. New `PATCH /sessions/{session_id}/title` endpoint with server-side
validation (rejects blank/whitespace-only titles, normalizes
whitespace). Pressing Enter or clicking away submits; Escape cancels
without submitting.

2. **New Chat button moved to top & sticky** — The 'New Chat' button is
now at the top of the sidebar (under 'Your chats') instead of the
footer, and stays fixed — only the session list below it scrolls. A
subtle shadow separator mirrors the original footer style.

3. **Auto-generated title appears live** — After the first message in a
new chat, the sidebar polls for the backend-generated title and animates
it in smoothly once available. The backend also guards against
auto-title overwriting a user-set title.

4. **External Link popup redesign** — Replaced the CSS-hacked external
link confirmation dialog with a proper AutoGPT `Dialog` component using
the design system (`Button`, `Text`, `Dialog`). Removed the old
`globals.css` workaround.

<img width="321" height="263" alt="Screenshot 2026-03-03 at 6 31 50 pm"
src="https://github.com/user-attachments/assets/3cdd1c6f-cca6-4f16-8165-15a1dc2d53f7"
/>

<img width="374" height="74" alt="Screenshot 2026-03-02 at 6 39 07 pm"
src="https://github.com/user-attachments/assets/6f9fc953-5fa7-4469-9eab-7074e7604519"
/>

<img width="548" height="293" alt="Screenshot 2026-03-02 at 6 36 28 pm"
src="https://github.com/user-attachments/assets/0f34683b-7281-4826-ac6f-ac7926e67854"
/>

### Changes 🏗️

**Backend:**
- `routes.py`: Added `PATCH /sessions/{session_id}/title` endpoint with
`UpdateSessionTitleRequest` Pydantic model — validates non-blank title,
normalizes whitespace, returns 404 vs 500 correctly
- `routes_test.py`: New test file — 7 test cases covering success,
whitespace trimming, blank rejection (422), not found (404), internal
failure (500)
- `service.py`: Auto-title generation now checks if a user-set title
already exists before overwriting
- `openapi.json`: Updated with new endpoint schema

**Frontend:**
- `ChatSidebar.tsx`: Inline rename (Enter/blur submits, Escape cancels
via ref flag); "New Chat" button sticky at top with shadow separator;
session title animates when auto-generated title appears
(`AnimatePresence`)
- `useCopilotPage.ts`: Polls for auto-generated title after stream ends,
stops as soon as title appears in cache
- `MobileDrawer.tsx`: Updated to match sidebar layout changes
- `DeleteChatDialog.tsx`: Removed redundant `onClose` prop (controlled
Dialog already handles close)
- `message.tsx`: Added `ExternalLinkModal` using AutoGPT design system;
removed redundant `onClose` prop
- `globals.css`: Removed old CSS hack for external link modal

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a new chat, send a message — verify auto-generated title
appears in sidebar without refresh
- [x] Rename a chat via dropdown — Enter submits, Escape reverts, blank
title rejected
- [x] Rename a chat, then send another message — verify user title is
not overwritten by auto-title
- [x] With many chats, scroll the sidebar — verify "New Chat" button
stays fixed at top
- [x] Click an external link in a message — verify the new dialog
appears with AutoGPT styling

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 06:01:41 +00:00
Zamil Majdy
0b9e0665dd Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT 2026-03-06 02:32:36 +07:00
Zamil Majdy
f6f268a1f0 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into HEAD 2026-03-06 02:29:56 +07:00
148 changed files with 10047 additions and 4370 deletions

View File

@@ -1,7 +1,7 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Literal, Optional, Sequence
from typing import Annotated, Any, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.blocks
from backend.api.external.middleware import require_permission
from backend.api.external.middleware import require_auth, require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
sorted_by: store_db.StoreAgentsSortOptions | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -278,7 +279,7 @@ async def get_store_agents(
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
@@ -306,13 +307,13 @@ async def get_store_agent(
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
@@ -348,7 +349,7 @@ async def get_store_creators(
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.CreatorDetails,
)
async def get_store_creator(

View File

@@ -1,4 +1,8 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from backend.data.model import UserTransaction
from backend.util.models import Pagination
@@ -14,3 +18,42 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class CreateInvitedUserRequest(BaseModel):
email: EmailStr
name: Optional[str] = None
class InvitedUserResponse(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
class BulkInvitedUserRowResponse(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]

View File

@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=store_model.StoreListingsWithVersionsResponse,
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
):
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
"""
Get store listings with their version history for admins.
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
page_size: Number of items per page
Returns:
StoreListingsWithVersionsResponse with listings and their versions
Paginated listings with their versions
"""
try:
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=store_model.StoreSubmission,
)
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
) -> store_model.StoreSubmissionAdminView:
"""
Review a store listing submission.
@@ -84,31 +73,24 @@ async def review_submission(
user_id: Authenticated admin user performing the review
Returns:
StoreSubmission with updated review information
StoreSubmissionAdminView with updated review information
"""
try:
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
state_changed = already_approved != request.is_approved
# Clear caches whenever approval state changes, since store visibility can change
if state_changed:
store_cache.clear_all_caches()
return submission
@router.get(

View File

@@ -0,0 +1,143 @@
import logging
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Security, UploadFile
from backend.data.invited_user import (
BulkInvitedUsersResult,
InvitedUserRecord,
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from .model import (
BulkInvitedUserRowResponse,
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["users", "admin"],
dependencies=[Security(requires_admin_user)],
)
def _to_response(invited_user: InvitedUserRecord) -> InvitedUserResponse:
return InvitedUserResponse(**invited_user.model_dump())
def _to_bulk_response(result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return BulkInvitedUsersResponse(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
_to_response(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
summary="List Invited Users",
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users = await list_invited_users()
return InvitedUsersResponse(
invited_users=[_to_response(invited_user) for invited_user in invited_users]
)
@router.post(
"/invited-users",
response_model=InvitedUserResponse,
summary="Create Invited User",
)
async def create_invited_user_route(
request: CreateInvitedUserRequest,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s created invited user for %s",
admin_user_id,
request.email,
)
invited_user = await create_invited_user(request.email, request.name)
return _to_response(invited_user)
@router.post(
"/invited-users/bulk",
response_model=BulkInvitedUsersResponse,
summary="Bulk Create Invited Users",
operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
file: UploadFile = File(...),
admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
logger.info(
"Admin user %s bulk invited users from %s",
admin_user_id,
file.filename or "<unnamed>",
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return _to_bulk_response(result)
@router.post(
"/invited-users/{invited_user_id}/revoke",
response_model=InvitedUserResponse,
summary="Revoke Invited User",
)
async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
invited_user = await revoke_invited_user(invited_user_id)
return _to_response(invited_user)
@router.post(
"/invited-users/{invited_user_id}/retry-tally",
response_model=InvitedUserResponse,
summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
return _to_response(invited_user)

View File

@@ -0,0 +1,165 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
now = datetime.now(timezone.utc)
return InvitedUserRecord(
id="invite-1",
email="invited@example.com",
status=prisma.enums.InvitedUserStatus.INVITED,
auth_user_id=None,
name="Invited User",
tally_understanding=None,
tally_status=prisma.enums.TallyComputationStatus.PENDING,
tally_computed_at=None,
tally_error=None,
created_at=now,
updated_at=now,
)
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
return BulkInvitedUsersResult(
created_count=1,
skipped_count=1,
error_count=0,
results=[
BulkInvitedUserRowResult(
row_number=1,
email="invited@example.com",
name=None,
status="CREATED",
message="Invite created",
invited_user=_sample_invited_user(),
),
BulkInvitedUserRowResult(
row_number=2,
email="duplicate@example.com",
name=None,
status="SKIPPED",
message="An invited user with this email already exists",
invited_user=None,
),
],
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=[_sample_invited_user()]),
)
response = client.get("/admin/invited-users")
assert response.status_code == 200
data = response.json()
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
def test_create_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.create_invited_user",
AsyncMock(return_value=_sample_invited_user()),
)
response = client.post(
"/admin/invited-users",
json={"email": "invited@example.com", "name": "Invited User"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "invited@example.com"
assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
AsyncMock(return_value=_sample_bulk_invited_users_result()),
)
response = client.post(
"/admin/invited-users/bulk",
files={
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
},
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert data["skipped_count"] == 1
assert data["results"][0]["status"] == "CREATED"
assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
revoked = _sample_invited_user().model_copy(
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
AsyncMock(return_value=revoked),
)
response = client.post("/admin/invited-users/invite-1/revoke")
assert response.status_code == 200
assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
mocker: pytest_mock.MockerFixture,
) -> None:
retried = _sample_invited_user().model_copy(
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
AsyncMock(return_value=retried),
)
response = client.post("/admin/invited-users/invite-1/retry-tally")
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"

View File

@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -25,6 +25,7 @@ from backend.copilot.model import (
delete_chat_session,
get_chat_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
@@ -141,6 +142,20 @@ class CancelSessionResponse(BaseModel):
reason: str | None = None
class UpdateSessionTitleRequest(BaseModel):
"""Request model for updating a session's title."""
title: str
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v: str) -> str:
stripped = v.strip()
if not stripped:
raise ValueError("Title must not be blank")
return stripped
# ========== Routes ==========
@@ -264,6 +279,43 @@ async def delete_session(
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
dependencies=[Security(auth.requires_user)],
status_code=200,
responses={404: {"description": "Session not found or access denied"}},
)
async def update_session_title_route(
session_id: str,
request: UpdateSessionTitleRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
"""
Update the title of a chat session.
Allows the user to rename their chat session.
Args:
session_id: The session ID to update.
request: Request body containing the new title.
user_id: The authenticated user's ID.
Returns:
dict: Status of the update.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
success = await update_session_title(session_id, user_id, request.title)
if not success:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return {"status": "ok"}
@router.get(
"/sessions/{session_id}",
)
@@ -753,7 +805,6 @@ async def resume_session_stream(
@router.patch(
"/sessions/{session_id}/assign-user",
dependencies=[Security(auth.requires_user)],
status_code=200,
)
async def session_assign_user(
session_id: str,

View File

@@ -1,4 +1,6 @@
"""Tests for chat route file_ids validation and enrichment."""
"""Tests for chat API routes: session title update and file attachment validation."""
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -17,6 +19,7 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
@@ -24,7 +27,95 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
# ---- file_ids Pydantic validation (B1) ----
def _mock_update_session_title(
mocker: pytest_mock.MockerFixture, *, success: bool = True
):
"""Mock update_session_title."""
return mocker.patch(
"backend.api.features.chat.routes.update_session_title",
new_callable=AsyncMock,
return_value=success,
)
# ─── Update title: success ─────────────────────────────────────────────
def test_update_title_success(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "My project"},
)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
def test_update_title_trims_whitespace(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": " trimmed "},
)
assert response.status_code == 200
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
def test_update_title_blank_rejected(
test_user_id: str,
) -> None:
"""Whitespace-only titles must be rejected before hitting the DB."""
response = client.patch(
"/sessions/sess-1/title",
json={"title": " "},
)
assert response.status_code == 422
def test_update_title_empty_rejected(
test_user_id: str,
) -> None:
response = client.patch(
"/sessions/sess-1/title",
json={"title": ""},
)
assert response.status_code == 422
# ─── Update title: session not found or wrong user → 404 ──────────────
def test_update_title_not_found(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
_mock_update_session_title(mocker, success=False)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "New name"},
)
assert response.status_code == 404
# ─── file_ids Pydantic validation ─────────────────────────────────────
def test_stream_chat_rejects_too_many_file_ids():
@@ -92,7 +183,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
# ---- UUID format filtering ----
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
@@ -131,7 +222,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ---- Cross-workspace file_ids ----
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):

View File

@@ -8,7 +8,6 @@ import prisma.errors
import prisma.models
import prisma.types
import backend.api.features.store.exceptions as store_exceptions
import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
The requested LibraryAgent.
Raises:
AgentNotFoundError: If the specified agent does not exist.
NotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during retrieval.
"""
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
@@ -398,6 +397,7 @@ async def create_library_agent(
hitl_safe_mode: bool = True,
sensitive_action_safe_mode: bool = False,
create_library_agents_for_sub_graphs: bool = True,
folder_id: str | None = None,
) -> list[library_model.LibraryAgent]:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -414,12 +414,18 @@ async def create_library_agent(
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
Raises:
AgentNotFoundError: If the specified agent does not exist.
NotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during creation or if image generation fails.
"""
logger.info(
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
)
# Authorization: FK only checks existence, not ownership.
# Verify the folder belongs to this user to prevent cross-user nesting.
if folder_id:
await get_folder(folder_id, user_id)
graph_entries = (
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
)
@@ -432,7 +438,6 @@ async def create_library_agent(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
@@ -448,6 +453,11 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
folder_id: str | None = None,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
@@ -542,6 +553,7 @@ async def create_graph_in_library(
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
folder_id=folder_id,
)
if created_graph.is_active:
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
AgentNotFoundError: If the store listing or associated agent is not found.
NotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
)
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
include_subgraphs=False,
)
if not graph_model:
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(
nodes: list[library_model.LibraryFolderTree],
visited: set[str] | None = None,
) -> list[str]:
"""Collect all folder IDs from a folder tree."""
if visited is None:
visited = set()
ids: list[str] = []
for n in nodes:
if n.id in visited:
continue
visited.add(n.id)
ids.append(n.id)
ids.extend(collect_tree_ids(n.children, visited))
return ids
async def get_folder_agent_summaries(
user_id: str, folder_id: str
) -> list[dict[str, str | None]]:
"""Get a lightweight list of agents in a folder (id, name, description)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, folder_id=folder_id, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_root_agent_summaries(
user_id: str,
) -> list[dict[str, str | None]]:
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, include_root_only=True, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_folder_agents_map(
user_id: str, folder_ids: list[str]
) -> dict[str, list[dict[str, str | None]]]:
"""Get agent summaries for multiple folders concurrently."""
results = await asyncio.gather(
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
)
return dict(zip(folder_ids, results))
##############################################
########### Presets DB Functions #############
##############################################

View File

@@ -4,7 +4,6 @@ import prisma.enums
import prisma.models
import pytest
import backend.api.features.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
)
# Call function and verify exception
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
with pytest.raises(db.NotFoundError):
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly

View File

@@ -1,5 +1,3 @@
from typing import Literal
from backend.util.cache import cached
from . import db as store_db
@@ -23,7 +21,7 @@ def clear_all_caches():
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
sorted_by: store_db.StoreAgentsSortOptions | None,
search_query: str | None,
category: str | None,
page: int,
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
sorted_by: store_db.StoreCreatorsSortOptions | None,
page: int,
page_size: int,
):
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await store_db.get_store_creator_details(username=username.lower())
return await store_db.get_store_creator(username=username.lower())

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
mock_agents = [
prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video=None,
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
graph_id="test-graph-id",
graph_versions=["1"],
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
use_for_onboarding=False,
)
]
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agent_details(mocker):
# Mock data
# Mock data - StoreAgent view already contains the active version data
mock_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
)
# Mock active version agent (what we want to return for active version)
mock_active_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="active-version-id",
slug="test-agent",
agent_name="Test Agent Active",
agent_video="active_video.mp4",
agent_image=["active_image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading active",
description="Test description active",
categories=["test"],
runs=15,
rating=4.8,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id-active",
graph_id="test-graph-id",
graph_versions=["1"],
updated_at=datetime.now(),
is_available=True,
useForOnboarding=False,
use_for_onboarding=False,
)
# Create a mock StoreListing result
mock_store_listing = mocker.MagicMock()
mock_store_listing.activeVersionId = "active-version-id"
mock_store_listing.hasApprovedVersion = True
mock_store_listing.ActiveVersion = mocker.MagicMock()
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
# Mock StoreAgent prisma call - need to handle multiple calls
# Mock StoreAgent prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
# Set up side_effect to return different results for different calls
def mock_find_first_side_effect(*args, **kwargs):
where_clause = kwargs.get("where", {})
if "storeListingVersionId" in where_clause:
# Second call for active version
return mock_active_agent
else:
# First call for initial lookup
return mock_agent
mock_store_agent.return_value.find_first = mocker.AsyncMock(
side_effect=mock_find_first_side_effect
)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_store_listing
)
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results - should use active version data
# Verify results - constructed from the StoreAgent view
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent Active" # From active version
assert result.active_version_id == "active-version-id"
assert result.agent_name == "Test Agent"
assert result.active_version_id == "version123"
assert result.has_approved_version is True
assert (
result.store_listing_version_id == "active-version-id"
) # Should be active version ID
assert result.store_listing_version_id == "version123"
assert result.graph_id == "test-graph-id"
assert result.runs == 10
assert result.rating == 4.5
# Verify mocks called correctly - now expecting 2 calls
assert mock_store_agent.return_value.find_first.call_count == 2
# Check the specific calls
calls = mock_store_agent.return_value.find_first.call_args_list
assert calls[0] == mocker.call(
# Verify single StoreAgent lookup
mock_store_agent.return_value.find_first.assert_called_once_with(
where={"creator_username": "creator", "slug": "test-agent"}
)
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creator_details(mocker):
async def test_get_store_creator(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
name="Test Creator",
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
mock_creator.return_value.find_unique.return_value = mock_creator_data
# Call function
result = await db.get_store_creator_details("creator")
result = await db.get_store_creator("creator")
# Verify results
assert result.username == "creator"
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_create_store_submission(mocker):
# Mock data
now = datetime.now()
# Mock agent graph (with no pending submissions) and user with profile
mock_profile = prisma.models.Profile(
id="profile-id",
userId="user-id",
name="Test User",
username="testuser",
description="Test",
isFeatured=False,
links=[],
createdAt=now,
updatedAt=now,
)
mock_user = prisma.models.User(
id="user-id",
email="test@example.com",
createdAt=now,
updatedAt=now,
Profile=[mock_profile],
emailVerified=True,
metadata="{}", # type: ignore[reportArgumentType]
integrations="",
maxEmailsPerDay=1,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",
version=1,
userId="user-id",
createdAt=datetime.now(),
createdAt=now,
isActive=True,
StoreListingVersions=[],
User=mock_user,
)
mock_listing = prisma.models.StoreListing(
# Mock the created StoreListingVersion (returned by create)
mock_store_listing_obj = prisma.models.StoreListing(
id="listing-id",
createdAt=datetime.now(),
updatedAt=datetime.now(),
createdAt=now,
updatedAt=now,
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentGraphId="agent-id",
agentGraphVersion=1,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),
updatedAt=datetime.now(),
subHeading="Test heading",
imageUrls=["image.jpg"],
categories=["test"],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
)
],
useForOnboarding=False,
)
mock_version = prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=now,
updatedAt=now,
subHeading="",
imageUrls=[],
categories=[],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
submittedAt=now,
StoreListing=mock_store_listing_obj,
)
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
# Call function
result = await db.create_store_submission(
user_id="user-id",
agent_id="agent-id",
agent_version=1,
graph_id="agent-id",
graph_version=1,
slug="test-agent",
name="Test Agent",
description="Test description",
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
# Verify results
assert result.name == "Test Agent"
assert result.description == "Test description"
assert result.store_listing_version_id == "version-id"
assert result.listing_version_id == "version-id"
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()
mock_slv.return_value.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
is_featured=False,
)
# Call function
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
creators=["creator1'; DROP TABLE Users; --", "creator2"],
category="AI'; DELETE FROM StoreAgent; --",
featured=True,
sorted_by="rating",
sorted_by=db.StoreAgentsSortOptions.RATING,
page=1,
page_size=20,
)

View File

@@ -57,12 +57,6 @@ class StoreError(ValueError):
pass
class AgentNotFoundError(NotFoundError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(NotFoundError):
"""Raised when a creator is not found"""

View File

@@ -568,7 +568,7 @@ async def hybrid_search(
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
ON uce."contentId" = sa.listing_version_id
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
@@ -582,7 +582,7 @@ async def hybrid_search(
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
ON uce."contentId" = sa.listing_version_id
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
@@ -605,7 +605,7 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
sa.graph_id,
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -627,9 +627,9 @@ async def hybrid_search(
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa."storeListingVersionId"
ON c."storeListingVersionId" = sa.listing_version_id
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId"
ON sa.listing_version_id = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
@@ -665,7 +665,7 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
graph_id,
searchable_text,
semantic_score,
lexical_score,

View File

@@ -1,11 +1,14 @@
import datetime
from typing import List
from typing import TYPE_CHECKING, List, Self
import prisma.enums
import pydantic
from backend.util.models import Pagination
if TYPE_CHECKING:
import prisma.models
class ChangelogEntry(pydantic.BaseModel):
version: str
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
date: datetime.datetime
class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
class MyUnpublishedAgent(pydantic.BaseModel):
graph_id: str
graph_version: int
agent_name: str
agent_image: str | None = None
description: str
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class MyAgentsResponse(pydantic.BaseModel):
agents: list[MyAgent]
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
agents: list[MyUnpublishedAgent]
pagination: Pagination
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
rating: float
agent_graph_id: str
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
return cls(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.graph_id,
)
class StoreAgentsResponse(pydantic.BaseModel):
agents: list[StoreAgent]
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
runs: int
rating: float
versions: list[str]
agentGraphVersions: list[str]
agentGraphId: str
graph_id: str
graph_versions: list[str]
last_updated: datetime.datetime
recommended_schedule_cron: str | None = None
active_version_id: str | None = None
has_approved_version: bool = False
active_version_id: str
has_approved_version: bool
# Optional changelog data when include_changelog=True
changelog: list[ChangelogEntry] | None = None
class Creator(pydantic.BaseModel):
name: str
username: str
description: str
avatar_url: str
num_agents: int
agent_rating: float
agent_runs: int
is_featured: bool
class CreatorsResponse(pydantic.BaseModel):
creators: List[Creator]
pagination: Pagination
class CreatorDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str]
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
return cls(
store_listing_version_id=agent.listing_version_id,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username or "",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
graph_id=agent.graph_id,
graph_versions=agent.graph_versions,
last_updated=agent.updated_at,
recommended_schedule_cron=agent.recommended_schedule_cron,
active_version_id=agent.listing_version_id,
has_approved_version=True, # StoreAgent view only has approved agents
)
class Profile(pydantic.BaseModel):
name: str
"""Marketplace user profile (only attributes that the user can update)"""
username: str
name: str
description: str
avatar_url: str | None
links: list[str]
avatar_url: str
is_featured: bool = False
class ProfileDetails(Profile):
"""Marketplace user profile (including read-only fields)"""
is_featured: bool
@classmethod
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
return cls(
name=profile.name,
username=profile.username,
avatar_url=profile.avatarUrl,
description=profile.description,
links=profile.links,
is_featured=profile.isFeatured,
)
class CreatorDetails(ProfileDetails):
"""Marketplace creator profile details, including aggregated stats"""
num_agents: int
agent_runs: int
agent_rating: float
top_categories: list[str]
@classmethod
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
return cls(
name=creator.name,
username=creator.username,
avatar_url=creator.avatar_url,
description=creator.description,
links=creator.links,
is_featured=creator.is_featured,
num_agents=creator.num_agents,
agent_runs=creator.agent_runs,
agent_rating=creator.agent_rating,
top_categories=creator.top_categories,
)
class CreatorsResponse(pydantic.BaseModel):
creators: List[CreatorDetails]
pagination: Pagination
class StoreSubmission(pydantic.BaseModel):
# From StoreListing:
listing_id: str
agent_id: str
agent_version: int
user_id: str
slug: str
# From StoreListingVersion:
listing_version_id: str
listing_version: int
graph_id: str
graph_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: str | None = None
instructions: str | None
categories: list[str]
image_urls: list[str]
date_submitted: datetime.datetime
status: prisma.enums.SubmissionStatus
runs: int
rating: float
store_listing_version_id: str | None = None
version: int | None = None # Actual version number from the database
video_url: str | None
agent_output_demo_url: str | None
submitted_at: datetime.datetime | None
changes_summary: str | None
status: prisma.enums.SubmissionStatus
reviewed_at: datetime.datetime | None = None
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
# Additional fields for editing
video_url: str | None = None
agent_output_demo_url: str | None = None
categories: list[str] = []
# Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count: int = 0
review_count: int = 0
review_avg_rating: float = 0.0
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
"""Construct from the StoreSubmission Prisma view."""
return cls(
listing_id=_sub.listing_id,
user_id=_sub.user_id,
slug=_sub.slug,
listing_version_id=_sub.listing_version_id,
listing_version=_sub.listing_version,
graph_id=_sub.graph_id,
graph_version=_sub.graph_version,
name=_sub.name,
sub_heading=_sub.sub_heading,
description=_sub.description,
instructions=_sub.instructions,
categories=_sub.categories,
image_urls=_sub.image_urls,
video_url=_sub.video_url,
agent_output_demo_url=_sub.agent_output_demo_url,
submitted_at=_sub.submitted_at,
changes_summary=_sub.changes_summary,
status=_sub.status,
reviewed_at=_sub.reviewed_at,
reviewer_id=_sub.reviewer_id,
review_comments=_sub.review_comments,
run_count=_sub.run_count,
review_count=_sub.review_count,
review_avg_rating=_sub.review_avg_rating,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
"""
Construct from the StoreListingVersion Prisma model (with StoreListing included)
"""
if not (_l := _lv.StoreListing):
raise ValueError("StoreListingVersion must have included StoreListing")
return cls(
listing_id=_l.id,
user_id=_l.owningUserId,
slug=_l.slug,
listing_version_id=_lv.id,
listing_version=_lv.version,
graph_id=_lv.agentGraphId,
graph_version=_lv.agentGraphVersion,
name=_lv.name,
sub_heading=_lv.subHeading,
description=_lv.description,
instructions=_lv.instructions,
categories=_lv.categories,
image_urls=_lv.imageUrls,
video_url=_lv.videoUrl,
agent_output_demo_url=_lv.agentOutputDemoUrl,
submitted_at=_lv.submittedAt,
changes_summary=_lv.changesSummary,
status=_lv.submissionStatus,
reviewed_at=_lv.reviewedAt,
reviewer_id=_lv.reviewerId,
review_comments=_lv.reviewComments,
)
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
pagination: Pagination
class StoreListingWithVersions(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
slug: str
agent_id: str
agent_version: int
active_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmission | None = None
versions: list[StoreSubmission] = []
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersions]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
graph_id: str = pydantic.Field(
..., min_length=1, description="Graph ID cannot be empty"
)
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
graph_version: int = pydantic.Field(
..., gt=0, description="Graph version must be greater than 0"
)
slug: str
name: str
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class ProfileDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str | None = None
class StoreSubmissionAdminView(StoreSubmission):
internal_comments: str | None # Private admin notes
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
return cls(
**StoreSubmission.from_db(_sub).model_dump(),
internal_comments=_sub.internal_comments,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
return cls(
**StoreSubmission.from_listing_version(_lv).model_dump(),
internal_comments=_lv.internalComments,
)
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
graph_id: str
slug: str
active_listing_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmissionAdminView | None = None
versions: list[StoreSubmissionAdminView] = []
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersionsAdminView]
pagination: Pagination
class StoreReview(pydantic.BaseModel):

View File

@@ -1,203 +0,0 @@
import datetime
import prisma.enums
from . import model as store_model
def test_pagination():
pagination = store_model.Pagination(
total_items=100, total_pages=5, current_page=2, page_size=20
)
assert pagination.total_items == 100
assert pagination.total_pages == 5
assert pagination.current_page == 2
assert pagination.page_size == 20
def test_store_agent():
agent = store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
response = store_model.StoreAgentsResponse(
agents=[
store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.agents) == 1
assert response.pagination.total_items == 1
def test_store_agent_details():
details = store_model.StoreAgentDetails(
store_listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_output_demo="demo.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["cat1", "cat2"],
runs=50,
rating=4.5,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=datetime.datetime.now(),
)
assert details.slug == "test-agent"
assert len(details.agent_image) == 2
assert len(details.categories) == 2
assert len(details.versions) == 2
def test_creator():
creator = store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
assert creator.name == "Test Creator"
assert creator.num_agents == 5
def test_creators_response():
response = store_model.CreatorsResponse(
creators=[
store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.creators) == 1
assert response.pagination.total_items == 1
def test_creator_details():
details = store_model.CreatorDetails(
name="Test Creator",
username="creator1",
description="Test description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["cat1", "cat2"],
)
assert details.name == "Test Creator"
assert len(details.links) == 2
assert details.agent_rating == 4.8
assert len(details.top_categories) == 2
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg", "image2.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
assert submission.name == "Test Agent"
assert len(submission.image_urls) == 2
assert submission.status == prisma.enums.SubmissionStatus.PENDING
def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.submissions) == 1
assert response.pagination.total_items == 1
def test_store_submission_request():
request = store_model.StoreSubmissionRequest(
agent_id="agent123",
agent_version=1,
slug="test-agent",
name="Test Agent",
sub_heading="Test subheading",
video_url="video.mp4",
image_urls=["image1.jpg", "image2.jpg"],
description="Test description",
categories=["cat1", "cat2"],
)
assert request.agent_id == "agent123"
assert request.agent_version == 1
assert len(request.image_urls) == 2
assert len(request.categories) == 2

View File

@@ -1,16 +1,17 @@
import logging
import tempfile
import typing
import urllib.parse
from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
from fastapi import Query, Security
from pydantic import BaseModel
import backend.data.graph
import backend.util.json
from backend.util.exceptions import NotFoundError
from backend.util.models import Pagination
from . import cache as store_cache
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
"/profile",
summary="Get user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.ProfileDetails,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_profile(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Get the profile details for the authenticated user.
Cached for 1 hour per user.
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Get the profile details for the authenticated user."""
profile = await store_db.get_user_profile(user_id)
if profile is None:
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Profile not found"},
)
raise NotFoundError("User does not have a profile yet")
return profile
@@ -57,98 +51,17 @@ async def get_profile(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.CreatorDetails,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def update_or_create_profile(
profile: store_model.Profile,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Update the store profile for the authenticated user.
Args:
profile (Profile): The updated profile details
user_id (str): ID of the authenticated user
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If there is an error updating the profile
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Update the store profile for the authenticated user."""
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
return updated_profile
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
response_model=store_model.StoreAgentsResponse,
)
async def get_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
creator (str | None, optional): Filter agents by creator username. Defaults to None.
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
category (str | None, optional): Filter agents by category. Defaults to None.
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of agents per page. Defaults to 20.
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
Raises:
HTTPException: If page or page_size are less than 1
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
##############################################
############### Search Endpoints #############
##############################################
@@ -158,60 +71,30 @@ async def get_agents(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[str] | None = fastapi.Query(
content_types: list[prisma.enums.ContentType] | None = Query(
default=None,
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
description="Content types to search. If not specified, searches all.",
),
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
user_id: str | None = Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
):
) -> store_model.UnifiedSearchResponse:
"""
Search across all content types (store agents, blocks, documentation) using hybrid search.
Search across all content types (marketplace agents, blocks, documentation)
using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_type_enums,
content_types=content_types,
user_id=user_id,
page=page,
page_size=page_size,
@@ -245,22 +128,69 @@ async def unified_search(
)
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
)
async def get_agents(
featured: bool = Query(
default=False, description="Filter to only show featured agents"
),
creator: str | None = Query(
default=None, description="Filter agents by creator username"
),
category: str | None = Query(default=None, description="Filter agents by category"),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
default=None,
description="Property to sort results by. Ignored if search_query is provided.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the marketplace,
with optional filtering and sorting.
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",
tags=["store", "public"],
response_model=store_model.StoreAgentDetails,
)
async def get_agent(
async def get_agent_by_name(
username: str,
agent_name: str,
include_changelog: bool = fastapi.Query(default=False),
):
"""
This is only used on the AgentDetails Page.
It returns the store listing agents details.
"""
include_changelog: bool = Query(default=False),
) -> store_model.StoreAgentDetails:
"""Get details of a marketplace agent"""
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
@@ -270,76 +200,82 @@ async def get_agent(
return agent
@router.get(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""
Get Agent Graph from Store Listing Version ID.
"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
"""
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.post(
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreReview,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def create_review(
async def post_user_review_for_agent(
username: str,
agent_name: str,
review: store_model.StoreReviewCreate,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a review for a store agent.
Args:
username: Creator's username
agent_name: Name/slug of the agent
review: Review details including score and optional comments
user_id: ID of authenticated user creating the review
Returns:
The created review
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreReview:
"""Post a user review on a marketplace agent listing"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await store_db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
@router.get(
"/listings/versions/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_agent_by_listing_version(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.get(
"/listings/versions/{store_listing_version_id}/graph",
summary="Get agent graph",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""Get outline of graph belonging to a specific marketplace listing version"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/listings/versions/{store_listing_version_id}/graph/download",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str,
) -> fastapi.responses.FileResponse:
"""Download agent graph file for a specific marketplace listing version"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
##############################################
############# Creator Endpoints #############
##############################################
@@ -349,37 +285,19 @@ async def create_review(
"/creators",
summary="List store creators",
tags=["store", "public"],
response_model=store_model.CreatorsResponse,
)
async def get_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
):
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
---
To support this functionality we need:
- featured: bool - to limit the list to just featured agents
- search_query: str - vector search based on the creators profile description.
- sorted_by: [agent_rating, agent_runs] -
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
featured: bool = Query(
default=False, description="Filter to only show featured creators"
),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.CreatorsResponse:
"""List or search marketplace creators"""
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
@@ -391,18 +309,12 @@ async def get_creators(
@router.get(
"/creator/{username}",
"/creators/{username}",
summary="Get creator details",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)
async def get_creator(
username: str,
):
"""
Get the details of a creator.
- Creator Details Page
"""
async def get_creator(username: str) -> store_model.CreatorDetails:
"""Get details on a marketplace creator"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
@@ -414,20 +326,17 @@ async def get_creator(
@router.get(
"/myagents",
"/my-unpublished-agents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.MyAgentsResponse,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_my_agents(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
"""
Get user's own agents.
"""
async def get_my_unpublished_agents(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.MyUnpublishedAgentsResponse:
"""List the authenticated user's unpublished agents"""
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
return agents
@@ -436,28 +345,17 @@ async def get_my_agents(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=bool,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def delete_submission(
submission_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Delete a store listing submission.
Args:
user_id (str): ID of the authenticated user
submission_id (str): ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> bool:
"""Delete a marketplace listing submission"""
result = await store_db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
@@ -465,37 +363,14 @@ async def delete_submission(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmissionsResponse,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_submissions(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of store submissions for the authenticated user.
Args:
user_id (str): ID of the authenticated user
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of submissions per page. Defaults to 20.
Returns:
StoreListingsResponse: Paginated list of store submissions
Raises:
HTTPException: If page or page_size are less than 1
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreSubmissionsResponse:
"""List the authenticated user's marketplace listing submissions"""
listings = await store_db.get_store_submissions(
user_id=user_id,
page=page,
@@ -508,30 +383,17 @@ async def get_submissions(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def create_submission(
submission_request: store_model.StoreSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new store listing submission.
Args:
submission_request (StoreSubmissionRequest): The submission details
user_id (str): ID of the authenticated user submitting the listing
Returns:
StoreSubmission: The created store submission
Raises:
HTTPException: If there is an error creating the submission
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Submit a new marketplace listing for review"""
result = await store_db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
graph_id=submission_request.graph_id,
graph_version=submission_request.graph_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
@@ -544,7 +406,6 @@ async def create_submission(
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -552,28 +413,14 @@ async def create_submission(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def edit_submission(
store_listing_version_id: str,
submission_request: store_model.StoreSubmissionEditRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Edit an existing store listing submission.
Args:
store_listing_version_id (str): ID of the store listing version to edit
submission_request (StoreSubmissionRequest): The updated submission details
user_id (str): ID of the authenticated user editing the listing
Returns:
StoreSubmission: The updated store submission
Raises:
HTTPException: If there is an error editing the submission
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Update a pending marketplace listing submission"""
result = await store_db.edit_store_submission(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
@@ -588,7 +435,6 @@ async def edit_submission(
changes_summary=submission_request.changes_summary,
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -596,115 +442,61 @@ async def edit_submission(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Upload media (images/videos) for a store listing submission.
Args:
file (UploadFile): The media file to upload
user_id (str): ID of the authenticated user uploading the media
Returns:
str: URL of the uploaded media file
Raises:
HTTPException: If there is an error uploading the media
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> str:
"""Upload media for a marketplace listing submission"""
media_url = await store_media.upload_media(user_id=user_id, file=file)
return media_url
class ImageURLResponse(BaseModel):
image_url: str
@router.post(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def generate_image(
agent_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> fastapi.responses.Response:
graph_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> ImageURLResponse:
"""
Generate an image for a store listing submission.
Args:
agent_id (str): ID of the agent to generate an image for
user_id (str): ID of the authenticated user
Returns:
JSONResponse: JSON containing the URL of the generated image
Generate an image for a marketplace listing submission based on the properties
of a given graph.
"""
agent = await backend.data.graph.get_graph(
graph_id=agent_id, version=None, user_id=user_id
graph = await backend.data.graph.get_graph(
graph_id=graph_id, version=None, user_id=user_id
)
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
if not graph:
raise NotFoundError(f"Agent graph #{graph_id} not found")
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{agent_id}.jpeg"
filename = f"agent_{graph_id}.jpeg"
existing_url = await store_media.check_media_exists(user_id, filename)
if existing_url:
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
logger.info(f"Using existing image for agent graph {graph_id}")
return ImageURLResponse(image_url=existing_url)
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(agent=agent)
image = await store_image_gen.generate_agent_image(agent=graph)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
@router.get(
"/download/agents/{store_listing_version_id}",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
return ImageURLResponse(image_url=image_url)
##############################################

View File

@@ -8,6 +8,8 @@ import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from backend.api.features.store.db import StoreAgentsSortOptions
from . import model as store_model
from . import routes as store_routes
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
mock_db_call.assert_called_once_with(
featured=False,
creators=None,
sorted_by="runs",
sorted_by=StoreAgentsSortOptions.RUNS,
search_query=None,
category=None,
page=1,
@@ -380,9 +382,11 @@ def test_get_agent_details(
runs=100,
rating=4.5,
versions=["1.0.0", "1.1.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
graph_versions=["1", "2"],
graph_id="test-graph-id",
last_updated=FIXED_NOW,
active_version_id="test-version-id",
has_approved_version=True,
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
mock_db_call.return_value = mocked_value
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
) -> None:
mocked_value = store_model.CreatorsResponse(
creators=[
store_model.Creator(
store_model.CreatorDetails(
name=f"Creator {i}",
username=f"creator{i}",
description=f"Creator {i} description",
avatar_url=f"avatar{i}.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=100,
description=f"Creator {i} description",
links=[f"user{i}.link.com"],
is_featured=False,
num_agents=1,
agent_runs=100,
agent_rating=4.5,
top_categories=["cat1", "cat2", "cat3"],
)
for i in range(5)
],
@@ -496,19 +502,19 @@ def test_get_creator_details(
mocked_value = store_model.CreatorDetails(
name="Test User",
username="creator1",
avatar_url="avatar.jpg",
description="Test creator description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
is_featured=True,
num_agents=5,
agent_runs=1000,
agent_rating=4.8,
top_categories=["category1", "category2"],
)
mock_db_call = mocker.patch(
"backend.api.features.store.db.get_store_creator_details"
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
mock_db_call.return_value = mocked_value
response = client.get("/creator/creator1")
response = client.get("/creators/creator1")
assert response.status_code == 200
data = store_model.CreatorDetails.model_validate(response.json())
@@ -528,19 +534,26 @@ def test_get_submissions_success(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],
date_submitted=FIXED_NOW,
status=prisma.enums.SubmissionStatus.APPROVED,
runs=50,
rating=4.2,
agent_id="test-agent-id",
agent_version=1,
sub_heading="Test agent subheading",
user_id="test-user-id",
slug="test-agent",
video_url="test.mp4",
listing_version_id="test-version-id",
listing_version=1,
graph_id="test-agent-id",
graph_version=1,
name="Test Agent",
sub_heading="Test agent subheading",
description="Test agent description",
instructions="Click the button!",
categories=["test-category"],
image_urls=["test.jpg"],
video_url="test.mp4",
agent_output_demo_url="demo_video.mp4",
submitted_at=FIXED_NOW,
changes_summary="Initial Submission",
status=prisma.enums.SubmissionStatus.APPROVED,
run_count=50,
review_count=5,
review_avg_rating=4.2,
)
],
pagination=store_model.Pagination(

View File

@@ -11,6 +11,7 @@ import pytest
from backend.util.models import Pagination
from . import cache as store_cache
from .db import StoreAgentsSortOptions
from .model import StoreAgent, StoreAgentsResponse
@@ -215,7 +216,7 @@ class TestCacheDeletion:
await store_cache._get_cached_store_agents(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,
@@ -227,7 +228,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,
@@ -239,7 +240,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,

View File

@@ -55,6 +55,7 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import get_or_activate_user
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
user = await get_or_activate_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -165,8 +163,11 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> dict[str, str]:
await get_or_activate_user(user_data)
await update_user_email(user_id, email)
return {"email": email}
@@ -182,7 +183,7 @@ async def get_user_timezone_route(
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
user = await get_or_activate_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -193,9 +194,12 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
await get_or_activate_user(user_data)
user = await update_user_timezone(user_id, str(request.timezone))
return TimezoneResponse(timezone=user.timezone)
@@ -208,7 +212,9 @@ async def update_user_timezone_route(
)
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
preferences = await get_user_notification_preference(user_id)
return preferences
@@ -222,7 +228,9 @@ async def get_preferences(
async def update_preferences(
user_id: Annotated[str, Security(get_user_id)],
preferences: NotificationPreferenceDTO = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
output = await update_user_notification_preference(user_id, preferences)
return output
@@ -449,7 +457,6 @@ async def execute_graph_block(
async def upload_file(
user_id: Annotated[str, Security(get_user_id)],
file: UploadFile = File(...),
provider: str = "gcs",
expiration_hours: int = 24,
) -> UploadFileResponse:
"""
@@ -512,7 +519,6 @@ async def upload_file(
storage_path = await cloud_storage.store_file(
content=content,
filename=file_name,
provider=provider,
expiration_hours=expiration_hours,
user_id=user_id,
)

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_create_user",
"backend.api.features.v1.get_or_activate_user",
return_value=mock_user,
)
@@ -515,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
result = await upload_file(
file=upload_file_mock,
user_id=test_user_id,
provider="gcs",
expiration_hours=24,
)
@@ -533,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
mock_handler.store_file.assert_called_once_with(
content=file_content,
filename="test.txt",
provider="gcs",
expiration_hours=24,
user_id=test_user_id,
)

View File

@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -55,6 +56,7 @@ from backend.util.exceptions import (
MissingConfigError,
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
@@ -275,6 +277,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
@@ -309,6 +312,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -1,8 +1,8 @@
import logging
from typing import Literal
from pydantic import BaseModel
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks._base import (
Block,
BlockCategory,
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
category: str | None = SchemaField(
description="Filter by category", default=None
)
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
description="How to sort the results", default="rating"
sort_by: StoreAgentsSortOptions = SchemaField(
description="How to sort the results", default=StoreAgentsSortOptions.RATING
)
limit: int = SchemaField(
description="Maximum number of results to return", default=10, ge=1, le=100
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
self,
query: str | None = None,
category: str | None = None,
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
limit: int = 10,
) -> SearchAgentsResponse:
"""

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
import pytest
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks.system.library_operations import (
AddToLibraryFromStoreBlock,
LibraryAgent,
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
)
input_data = block.Input(
query="test", category="productivity", sort_by="rating", limit=10
query="test",
category="productivity",
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
limit=10,
)
outputs = {}

View File

@@ -22,6 +22,7 @@ from backend.copilot.model import (
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -62,8 +63,8 @@ async def _update_title_async(
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title:
await update_session_title(session_id, title)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
@@ -176,14 +177,17 @@ async def stream_chat_completion_baseline(
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
system_prompt, _ = await _build_system_prompt(
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
system_prompt, _ = await _build_system_prompt(
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)

View File

@@ -81,6 +81,35 @@ async def update_chat_session(
return ChatSession.from_db(session) if session else None
async def update_chat_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
Always filters by (session_id, user_id) so callers cannot mutate another
user's session even when they know the session_id.
Args:
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
guard so auto-generated titles never overwrite a user-set title.
Returns True if a row was updated, False otherwise (session not found,
wrong user, or — when only_if_empty — title was already set).
"""
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
if only_if_empty:
where["title"] = None
result = await PrismaChatSession.prisma().update_many(
where=where,
data={"title": title, "updatedAt": datetime.now(UTC)},
)
return result > 0
async def add_chat_message(
session_id: str,
role: str,

View File

@@ -469,8 +469,16 @@ async def upsert_chat_session(
)
db_error = e
# Save to cache (best-effort, even if DB failed)
# Save to cache (best-effort, even if DB failed).
# Title updates (update_session_title) run *outside* this lock because
# they only touch the title field, not messages. So a concurrent rename
# or auto-title may have written a newer title to Redis while this
# upsert was in progress. Always prefer the cached title to avoid
# overwriting it with the stale in-memory copy.
try:
existing_cached = await _get_session_from_cache(session.session_id)
if existing_cached and existing_cached.title:
session = session.model_copy(update={"title": existing_cached.title})
await cache_chat_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
@@ -685,30 +693,48 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return True
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
async def update_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Lightweight operation that doesn't touch messages, avoiding race conditions
with concurrent message updates.
Args:
session_id: The session ID to update.
user_id: Owning user — the DB query filters on this.
title: The new title to set.
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
so auto-generated titles never overwrite a user-set title.
Returns:
True if updated successfully, False otherwise.
True if updated successfully, False otherwise (not found, wrong user,
or — when only_if_empty — title was already set).
"""
try:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
updated = await chat_db().update_chat_session_title(
session_id, user_id, title, only_if_empty=only_if_empty
)
if not updated:
return False
# Invalidate the cache so the next access reloads from DB with the
# updated title. This avoids a read-modify-write on the full session
# blob, which could overwrite concurrent message updates.
await invalidate_session_cache(session_id)
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
except Exception as e:
logger.warning(
f"Cache title update failed for session {session_id} (non-critical): {e}"
)
return True
except Exception as e:

View File

@@ -0,0 +1,191 @@
"""Centralized prompt building logic for CoPilot.
This module contains all prompt construction functions and constants,
handling the distinction between:
- SDK mode vs Baseline mode (tool documentation needs)
- Local mode vs E2B mode (storage/filesystem differences)
"""
from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
working_dir: str,
sandbox_type: str,
storage_system_1_name: str,
storage_system_1_characteristics: list[str],
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
) -> str:
"""Build storage/filesystem supplement for a specific environment.
Template function handles all formatting (bullets, indentation, markdown).
Callers provide clean data as lists of strings.
Args:
working_dir: Working directory path
sandbox_type: Description of bash_exec sandbox
storage_system_1_name: Name of primary storage (ephemeral or cloud)
storage_system_1_characteristics: List of characteristic descriptions
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
return f"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs {sandbox_type}.
### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations
### Two storage systems — CRITICAL to understand
1. **{storage_system_1_name}** (`{working_dir}`):
{characteristics}
{persistence}
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
storage_system_1_name="Ephemeral working directory",
storage_system_1_characteristics=[
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
],
storage_system_1_persistence=[
"Files here are **lost between turns** — do NOT rely on them persisting",
"Use for temporary work: running scripts, processing data, etc.",
],
file_move_name_1_to_2="Ephemeral → Persistent",
file_move_name_2_to_1="Persistent → Ephemeral",
)
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
storage_system_1_name="Cloud sandbox",
storage_system_1_characteristics=[
"Shared by all file tools AND `bash_exec` — same filesystem",
"Full Linux environment with internet access",
],
storage_system_1_persistence=[
"Files **persist across turns** within the current session",
"Lost when the session expires (12 h inactivity)",
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
)
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
SDK mode doesn't need it since Claude gets tool schemas automatically.
This generates a complete list of available tools with their descriptions,
ensuring the documentation stays in sync with the actual tool implementations.
All workflow guidance is now embedded in individual tool descriptions.
Only documents tools that are available in the current environment
(checked via tool.is_available property).
"""
docs = "\n## AVAILABLE TOOLS\n\n"
# Sort tools alphabetically for consistent output
# Filter by is_available to match get_available_tools() behavior
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
if not tool.is_available:
continue
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
docs += f"- **`{name}`**: {desc}\n"
return docs
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
"""Get the supplement for SDK mode (Claude Agent SDK).
SDK mode does NOT include tool documentation because Claude automatically
receives tool schemas from the SDK. Only includes technical notes about
storage systems and execution environment.
Args:
use_e2b: Whether E2B cloud sandbox is being used
cwd: Current working directory (only used in local_storage mode)
Returns:
The supplement string to append to the system prompt
"""
if use_e2b:
return _get_cloud_sandbox_supplement()
return _get_local_storage_supplement(cwd)
def get_baseline_supplement() -> str:
"""Get the supplement for baseline mode (direct OpenAI API).
Baseline mode INCLUDES auto-generated tool documentation because the
direct API doesn't automatically provide tool schemas to Claude.
Also includes shared technical notes (but NOT SDK-specific environment details).
Returns:
The supplement string to append to the system prompt
"""
tool_docs = _generate_tool_documentation()
return tool_docs + _SHARED_TOOL_NOTES

View File

@@ -127,7 +127,6 @@ def create_security_hooks(
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[], None] | None = None,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -136,15 +135,12 @@ def create_security_hooks(
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
on_compact: Callback invoked when SDK starts compacting context.
Returns:
Hooks configuration dict for ClaudeAgentOptions
@@ -311,30 +307,6 @@ def create_security_hooks(
on_compact()
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---
async def stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Capture transcript path when SDK finishes processing.
The Stop hook fires while the CLI process is still alive, giving us
a reliable window to read the JSONL transcript before SIGTERM.
"""
_ = context, tool_use_id
transcript_path = cast(str, input_data.get("transcript_path", ""))
sdk_session_id = cast(str, input_data.get("session_id", ""))
if transcript_path and on_stop:
logger.info(
f"[SDK] Stop hook: transcript_path={transcript_path}, "
f"sdk_session_id={sdk_session_id[:12]}..."
)
on_stop(transcript_path, sdk_session_id)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
@@ -344,9 +316,6 @@ def create_security_hooks(
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
if on_stop is not None:
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
return hooks
except ImportError:
# Fallback for when SDK isn't available - return empty hooks

View File

@@ -12,7 +12,6 @@ import subprocess
import sys
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, cast
import openai
@@ -21,6 +20,9 @@ from claude_agent_sdk import (
ClaudeAgentOptions,
ClaudeSDKClient,
ResultMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
)
from langfuse import propagate_attributes
@@ -42,6 +44,7 @@ from ..model import (
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -74,11 +77,11 @@ from .tool_adapter import (
from .transcript import (
cleanup_cli_project_dir,
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
logger = logging.getLogger(__name__)
config = ChatConfig()
@@ -137,19 +140,6 @@ _setup_langfuse_otel()
_background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
raw_content: str = ""
@property
def available(self) -> bool:
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
@@ -157,140 +147,6 @@ _SDK_CWD_PREFIX = WORKSPACE_PREFIX
_HEARTBEAT_INTERVAL = 10.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Long-running tools
Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
### Large tool outputs
When a tool output exceeds the display limit, it is automatically saved to
the persistent workspace. The truncated output includes a
`<tool-output-truncated>` tag with the workspace path. Use
`read_workspace_file(path="...", offset=N, length=50000)` to retrieve
additional sections.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
_LOCAL_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
### Working directory
- Your working directory is: `{cwd}`
- All SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec` operate inside this
directory. This is the ONLY writable path — do not attempt to read or write
anywhere else on the filesystem.
- Use relative paths or absolute paths under `{cwd}` for all file operations.
### Two storage systems — CRITICAL to understand
1. **Ephemeral working directory** (`{cwd}`):
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
- Files here are **lost between turns** — do NOT rely on them persisting
- Use for temporary work: running scripts, processing data, etc.
2. **Persistent workspace** (cloud storage):
- Files here **survive across turns and sessions**
- Use `write_workspace_file` to save important files (code, outputs, configs)
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between ephemeral and persistent storage
- **Ephemeral → Persistent**: Use `write_workspace_file` with either:
- `content` param (plain text) — for text files
- `source_path` param — to copy any file directly from the ephemeral dir
- **Persistent → Ephemeral**: Use `read_workspace_file` with `save_to_path`
param to download a workspace file to the ephemeral dir for processing
### File persistence workflow
When you create or modify important files (code, configs, outputs), you MUST:
1. Save them using `write_workspace_file` so they persist
2. At the start of a new turn, call `list_workspace_files` to see what files
are available from previous turns
"""
+ _SHARED_TOOL_NOTES
)
_E2B_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a cloud sandbox with full internet access.
### Working directory
- Your working directory is: `/home/user` (cloud sandbox)
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
- Files created by `bash_exec` are immediately visible to `read_file` and
vice-versa — they share one filesystem.
- Use relative paths (resolved from `/home/user`) or absolute paths.
### Two storage systems — CRITICAL to understand
1. **Cloud sandbox** (`/home/user`):
- Shared by all file tools AND `bash_exec` — same filesystem
- Files **persist across turns** within the current session
- Full Linux environment with internet access
- Lost when the session expires (12 h inactivity)
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
- Use `write_workspace_file` to save important files permanently
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between sandbox and persistent storage
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
to copy from the sandbox to permanent storage
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
to download into the sandbox for processing
### File persistence workflow
Important files that must survive beyond this session should be saved with
`write_workspace_file`. Sandbox files persist across turns but are lost
when the session expires.
"""
+ _SHARED_TOOL_NOTES
)
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
@@ -451,6 +307,50 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
pass
def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"""Convert SDK content blocks to transcript format.
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
Unknown block types are logged and skipped.
"""
result: list[dict[str, Any]] = []
for block in blocks or []:
if isinstance(block, TextBlock):
result.append({"type": "text", "text": block.text})
elif isinstance(block, ToolUseBlock):
result.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
elif isinstance(block, ToolResultBlock):
tool_result_entry: dict[str, Any] = {
"type": "tool_result",
"tool_use_id": block.tool_use_id,
"content": block.content,
}
if block.is_error:
tool_result_entry["is_error"] = True
result.append(tool_result_entry)
elif isinstance(block, ThinkingBlock):
result.append(
{
"type": "thinking",
"thinking": block.thinking,
"signature": block.signature,
}
)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
f"This may indicate a new SDK version with additional block types."
)
return result
async def _compress_messages(
messages: list[ChatMessage],
) -> tuple[list[ChatMessage], bool]:
@@ -806,6 +706,11 @@ async def stream_chat_completion_sdk(
user_id=user_id, session_id=session_id, message_length=len(message)
)
# Structured log prefix: [SDK][<session>][T<turn>]
# Turn = number of user messages (1-based), computed AFTER appending the new message.
turn = sum(1 for m in session.messages if m.role == "user")
log_prefix = f"[SDK][{session_id[:12]}][T{turn}]"
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
@@ -823,10 +728,11 @@ async def stream_chat_completion_sdk(
message_id = str(uuid.uuid4())
stream_id = str(uuid.uuid4())
stream_completed = False
ended_with_stream_error = False
e2b_sandbox = None
use_resume = False
resume_file: str | None = None
captured_transcript = CapturedTranscript()
transcript_builder = TranscriptBuilder()
sdk_cwd = ""
# Acquire stream lock to prevent concurrent streams to the same session
@@ -841,7 +747,7 @@ async def stream_chat_completion_sdk(
if lock_owner != stream_id:
# Another stream is active
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
f"{log_prefix} Session already has an active stream: {lock_owner}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
@@ -865,7 +771,7 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
except (ValueError, OSError) as e:
logger.error("[SDK] [%s] Invalid SDK cwd: %s", session_id[:12], e)
logger.error("%s Invalid SDK cwd: %s", log_prefix, e)
yield StreamError(
errorText="Unable to initialize working directory.",
code="sdk_cwd_error",
@@ -909,12 +815,13 @@ async def stream_chat_completion_sdk(
):
return None
try:
return await download_transcript(user_id, session_id)
return await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as transcript_err:
logger.warning(
"[SDK] [%s] Transcript download failed, continuing without "
"--resume: %s",
session_id[:12],
"%s Transcript download failed, continuing without " "--resume: %s",
log_prefix,
transcript_err,
)
return None
@@ -926,21 +833,27 @@ async def stream_chat_completion_sdk(
)
use_e2b = e2b_sandbox is not None
system_prompt = base_system_prompt + (
_E2B_TOOL_SUPPLEMENT
if use_e2b
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
# Append appropriate supplement (Claude gets tool schemas automatically)
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
transcript_msg_count = 0
if dl:
is_valid = validate_transcript(dl.content)
dl_lines = dl.content.strip().split("\n") if dl.content else []
logger.info(
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
log_prefix,
len(dl.content),
len(dl_lines),
dl.message_count,
is_valid,
)
if is_valid:
logger.info(
f"[SDK] Transcript available for session {session_id}: "
f"{len(dl.content)}B, msg_count={dl.message_count}"
)
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
)
@@ -948,16 +861,14 @@ async def stream_chat_completion_sdk(
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"[SDK] Using --resume ({len(dl.content)}B, "
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning(
f"[SDK] Transcript downloaded but invalid for {session_id}"
)
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"[SDK] No transcript available for {session_id} "
f"{log_prefix} No transcript available "
f"({len(session.messages)} messages in session)"
)
@@ -979,25 +890,6 @@ async def stream_chat_completion_sdk(
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
# Read the file content immediately — the SDK may clean up
# the file before our finally block runs.
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
content = read_transcript_file(transcript_path)
if content:
captured_transcript.raw_content = content
logger.info(
f"[SDK] Stop hook: captured {len(content)}B from "
f"{transcript_path}"
)
else:
logger.warning(
f"[SDK] Stop hook: transcript file empty/missing at "
f"{transcript_path}"
)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
@@ -1005,7 +897,6 @@ async def stream_chat_completion_sdk(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
on_compact=compaction.on_compact,
)
@@ -1040,7 +931,10 @@ async def stream_chat_completion_sdk(
session_id=session_id,
trace_name="copilot-sdk",
tags=["sdk"],
metadata={"resume": str(use_resume)},
metadata={
"resume": str(use_resume),
"conversation_turn": str(turn),
},
)
_otel_ctx.__enter__()
@@ -1074,9 +968,9 @@ async def stream_chat_completion_sdk(
query_message = f"{query_message}\n\n{attachments.hint}"
logger.info(
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, "
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
session_id[:12],
log_prefix,
use_resume,
len(session.messages),
len(query_message),
@@ -1105,15 +999,19 @@ async def stream_chat_completion_sdk(
await client._transport.write( # noqa: SLF001
json.dumps(user_msg) + "\n"
)
# Capture user message in transcript (multimodal)
transcript_builder.append_user(content=content_blocks)
else:
await client.query(query_message, session_id=session_id)
# Capture actual user message in transcript (not the engineered query)
# query_message may include context wrappers, but transcript needs raw input
transcript_builder.append_user(content=current_message)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
ended_with_stream_error = False
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
@@ -1150,8 +1048,8 @@ async def stream_chat_completion_sdk(
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
"%s Stream ended normally (StopAsyncIteration)",
log_prefix,
)
break
except Exception as stream_err:
@@ -1160,8 +1058,8 @@ async def stream_chat_completion_sdk(
# so the session can still be saved and the
# frontend gets a clean finish.
logger.error(
"[SDK] [%s] Stream error from SDK: %s",
session_id[:12],
"%s Stream error from SDK: %s",
log_prefix,
stream_err,
exc_info=True,
)
@@ -1173,9 +1071,9 @@ async def stream_chat_completion_sdk(
break
logger.info(
"[SDK] [%s] Received: %s %s "
"%s Received: %s %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
log_prefix,
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
len(adapter.current_tool_calls)
@@ -1210,10 +1108,10 @@ async def stream_chat_completion_sdk(
await asyncio.sleep(0)
else:
logger.warning(
"[SDK] [%s] Timed out waiting for "
"%s Timed out waiting for "
"PostToolUse hook stash "
"(%d unresolved tool calls)",
session_id[:12],
log_prefix,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
@@ -1221,9 +1119,9 @@ async def stream_chat_completion_sdk(
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"[SDK] [%s] Received: ResultMessage %s "
"%s Received: ResultMessage %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
log_prefix,
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
@@ -1232,8 +1130,8 @@ async def stream_chat_completion_sdk(
)
if sdk_msg.subtype in ("error", "error_during_execution"):
logger.error(
"[SDK] [%s] SDK execution failed with error: %s",
session_id[:12],
"%s SDK execution failed with error: %s",
log_prefix,
sdk_msg.result or "(no error message provided)",
)
@@ -1258,8 +1156,8 @@ async def stream_chat_completion_sdk(
out_len = len(str(response.output))
extra = f", output_len={out_len}"
logger.info(
"[SDK] [%s] Tool event: %s, tool=%s%s",
session_id[:12],
"%s Tool event: %s, tool=%s%s",
log_prefix,
type(response).__name__,
getattr(response, "toolName", "N/A"),
extra,
@@ -1268,8 +1166,8 @@ async def stream_chat_completion_sdk(
# Log errors being sent to frontend
if isinstance(response, StreamError):
logger.error(
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
session_id[:12],
"%s Sending error to frontend: %s (code=%s)",
log_prefix,
response.errorText,
response.code,
)
@@ -1314,29 +1212,44 @@ async def stream_chat_completion_sdk(
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
content = (
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
)
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
content=content,
tool_call_id=response.toolCallId,
)
)
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
# Append assistant entry AFTER convert_message so that
# any stashed tool results from the previous turn are
# recorded first, preserving the required API order:
# assistant(tool_use) → tool_result → assistant(text).
if isinstance(sdk_msg, AssistantMessage):
transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
model=sdk_msg.model,
)
except asyncio.CancelledError:
# Task/generator was cancelled (e.g. client disconnect,
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
"%s Streaming loop cancelled (asyncio.CancelledError)",
log_prefix,
)
raise
finally:
@@ -1350,7 +1263,8 @@ async def stream_chat_completion_sdk(
except (asyncio.CancelledError, StopAsyncIteration):
# Expected: task was cancelled or exhausted during cleanup
logger.info(
"[SDK] Pending __anext__ task completed during cleanup"
"%s Pending __anext__ task completed during cleanup",
log_prefix,
)
# Safety net: if tools are still unresolved after the
@@ -1359,9 +1273,9 @@ async def stream_chat_completion_sdk(
# them now so the frontend stops showing spinners.
if adapter.has_unresolved_tool_calls:
logger.warning(
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
"%s %d unresolved tool(s) after stream loop — "
"flushing as safety net",
session_id[:12],
log_prefix,
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
@@ -1372,11 +1286,20 @@ async def stream_chat_completion_sdk(
(StreamToolInputAvailable, StreamToolOutputAvailable),
):
logger.info(
"[SDK] [%s] Safety flush: %s, tool=%s",
session_id[:12],
"%s Safety flush: %s, tool=%s",
log_prefix,
type(response).__name__,
getattr(response, "toolName", "N/A"),
)
if isinstance(response, StreamToolOutputAvailable):
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=(
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
),
)
yield response
# If the stream ended without a ResultMessage, the SDK
@@ -1386,8 +1309,8 @@ async def stream_chat_completion_sdk(
# StreamFinish is published by mark_session_completed in the processor.
if not stream_completed and not ended_with_stream_error:
logger.info(
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
session_id[:12],
"%s Stream ended without ResultMessage (stopped by user)",
log_prefix,
)
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
@@ -1408,69 +1331,36 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Upload transcript for next-turn --resume ---
# After async with the SDK task group has exited, so the Stop
# hook has already fired and the CLI has been SIGTERMed. The
# CLI uses appendFileSync, so all writes are safely on disk.
if config.claude_agent_use_resume and user_id:
# With --resume the CLI appends to the resume file (most
# complete). Otherwise use the Stop hook path.
if use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
logger.debug("[SDK] Transcript source: resume file")
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
else:
raw_transcript = None
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
# stop hook content — which could be smaller after compaction).
if not raw_transcript:
logger.debug(
"[SDK] No usable transcript — CLI file had no "
"conversation entries (expected for first turn "
"without --resume)"
)
if raw_transcript:
# Shield the upload from generator cancellation so a
# client disconnect / page refresh doesn't lose the
# transcript. The upload must finish even if the SSE
# connection is torn down.
await asyncio.shield(
_try_upload_transcript(
user_id,
session_id,
raw_transcript,
message_count=len(session.messages),
)
)
logger.info(
"[SDK] [%s] Stream completed successfully with %d messages",
session_id[:12],
len(session.messages),
)
if ended_with_stream_error:
logger.warning(
"%s Stream ended with SDK error after %d messages",
log_prefix,
len(session.messages),
)
else:
logger.info(
"%s Stream completed successfully with %d messages",
log_prefix,
len(session.messages),
)
except BaseException as e:
# Catch BaseException to handle both Exception and CancelledError
# (CancelledError inherits from BaseException in Python 3.8+)
if isinstance(e, asyncio.CancelledError):
logger.warning("[SDK] [%s] Session cancelled", session_id[:12])
logger.warning("%s Session cancelled", log_prefix)
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
# SDK cleanup RuntimeError is expected during cancellation, log as warning
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.warning(
"[SDK] [%s] SDK cleanup error: %s", session_id[:12], error_msg
)
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
else:
logger.error(
f"[SDK] [%s] Error: {error_msg}", session_id[:12], exc_info=True
)
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
# Append error marker to session (non-invasive text parsing approach)
# The finally block will persist the session with this error marker
@@ -1481,8 +1371,8 @@ async def stream_chat_completion_sdk(
)
)
logger.debug(
"[SDK] [%s] Appended error marker, will be persisted in finally",
session_id[:12],
"%s Appended error marker, will be persisted in finally",
log_prefix,
)
# Yield StreamError for immediate feedback (only for non-cancellation errors)
@@ -1514,47 +1404,61 @@ async def stream_chat_completion_sdk(
try:
await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session persisted in finally with %d messages",
session_id[:12],
"%s Session persisted in finally with %d messages",
log_prefix,
len(session.messages),
)
except Exception as persist_err:
logger.error(
"[SDK] [%s] Failed to persist session in finally: %s",
session_id[:12],
"%s Failed to persist session in finally: %s",
log_prefix,
persist_err,
exc_info=True,
)
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception. The CLI uses
# appendFileSync, so whatever was written before the error/SIGTERM
# is safely on disk and still useful for the next turn.
if config.claude_agent_use_resume and user_id:
# the streaming loop raises an exception.
# The transcript represents the COMPLETE active context (atomic).
if config.claude_agent_use_resume and user_id and session is not None:
try:
# Prefer content captured in the Stop hook (read before
# cleanup removes the file). Fall back to the resume
# file when the stop hook didn't fire (e.g. error before
# completion) so we don't lose the prior transcript.
raw_transcript = captured_transcript.raw_content or None
if not raw_transcript and use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
# Build complete transcript from captured SDK messages
transcript_content = transcript_builder.to_jsonl()
if raw_transcript and session is not None:
await asyncio.shield(
_try_upload_transcript(
user_id,
session_id,
raw_transcript,
message_count=len(session.messages),
)
if not transcript_content:
logger.warning(
"%s No transcript to upload (builder empty)", log_prefix
)
elif not validate_transcript(transcript_content):
logger.warning(
"%s Transcript invalid, skipping upload (entries=%d)",
log_prefix,
transcript_builder.entry_count,
)
else:
logger.warning(f"[SDK] No transcript to upload for {session_id}")
logger.info(
"%s Uploading complete transcript (entries=%d, bytes=%d)",
log_prefix,
transcript_builder.entry_count,
len(transcript_content),
)
# Shield upload from cancellation - let it complete even if
# the finally block is interrupted. No timeout to avoid race
# conditions where backgrounded uploads overwrite newer transcripts.
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=transcript_content,
message_count=len(session.messages),
log_prefix=log_prefix,
)
)
except Exception as upload_err:
logger.error(
f"[SDK] Transcript upload failed in finally: {upload_err}",
"%s Transcript upload failed in finally: %s",
log_prefix,
upload_err,
exc_info=True,
)
@@ -1565,33 +1469,6 @@ async def stream_chat_completion_sdk(
await lock.release()
async def _try_upload_transcript(
user_id: str,
session_id: str,
raw_content: str,
message_count: int = 0,
) -> bool:
"""Strip progress entries and upload transcript (with timeout).
Returns True if the upload completed without error.
"""
try:
async with asyncio.timeout(30):
await upload_transcript(
user_id, session_id, raw_content, message_count=message_count
)
return True
except asyncio.TimeoutError:
logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
return False
except Exception as e:
logger.error(
f"[SDK] Failed to upload transcript for {session_id}: {e}",
exc_info=True,
)
return False
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
@@ -1600,8 +1477,8 @@ async def _update_title_async(
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -145,3 +145,103 @@ class TestPrepareFileAttachments:
assert "Read tool" not in result.hint
assert len(result.image_blocks) == 1
class TestPromptSupplement:
"""Tests for centralized prompt supplement generation."""
def test_sdk_supplement_excludes_tool_docs(self):
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
from backend.copilot.prompting import get_sdk_supplement
# Test both local and E2B modes
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
# Should NOT have tool list section
assert "## AVAILABLE TOOLS" not in local_supplement
assert "## AVAILABLE TOOLS" not in e2b_supplement
# Should still have technical notes
assert "## Tool notes" in local_supplement
assert "## Tool notes" in e2b_supplement
def test_baseline_supplement_includes_tool_docs(self):
"""Baseline mode MUST include tool documentation (direct API needs it)."""
from backend.copilot.prompting import get_baseline_supplement
supplement = get_baseline_supplement()
# MUST have tool list section
assert "## AVAILABLE TOOLS" in supplement
# Should NOT have environment-specific notes (SDK-only)
assert "## Tool notes" not in supplement
def test_baseline_supplement_includes_key_tools(self):
"""Baseline supplement should document all essential tools."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Core agent workflow tools (always available)
assert "`create_agent`" in docs
assert "`run_agent`" in docs
assert "`find_library_agent`" in docs
assert "`edit_agent`" in docs
# MCP integration (always available)
assert "`run_mcp_tool`" in docs
# Folder management (always available)
assert "`create_folder`" in docs
# Browser tools only if available (Playwright may not be installed in CI)
if (
TOOL_REGISTRY.get("browser_navigate")
and TOOL_REGISTRY["browser_navigate"].is_available
):
assert "`browser_navigate`" in docs
def test_baseline_supplement_includes_workflows(self):
"""Baseline supplement should include workflow guidance in tool descriptions."""
from backend.copilot.prompting import get_baseline_supplement
docs = get_baseline_supplement()
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "suggested_goal" in docs or "clarifying_questions" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Verify each available registered tool is documented
# (matches _generate_tool_documentation which filters by is_available)
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
assert (
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Count occurrences of each available tool in the entire supplement
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"

View File

@@ -10,13 +10,14 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from backend.util import json
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -58,41 +59,37 @@ def strip_progress_entries(content: str) -> str:
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
any remaining child entries so the ``parentUuid`` chain stays intact.
Typically reduces transcript size by ~30%.
Entries that are not stripped or reparented are kept as their original
raw JSON line to avoid unnecessary re-serialization that changes
whitespace or key ordering.
"""
lines = content.strip().split("\n")
entries: list[dict] = []
# Parse entries, keeping the original line alongside the parsed dict.
parsed: list[tuple[str, dict | None]] = []
for line in lines:
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
# Keep unparseable lines as-is (safety)
entries.append({"_raw": line})
parsed.append((line, json.loads(line, fallback=None)))
# First pass: identify stripped UUIDs and build parent map.
stripped_uuids: set[str] = set()
uuid_to_parent: dict[str, str] = {}
kept: list[dict] = []
for entry in entries:
if "_raw" in entry:
kept.append(entry)
for _line, entry in parsed:
if not isinstance(entry, dict):
continue
uid = entry.get("uuid", "")
parent = entry.get("parentUuid", "")
entry_type = entry.get("type", "")
if uid:
uuid_to_parent[uid] = parent
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
stripped_uuids.add(uid)
if entry_type in STRIPPABLE_TYPES:
if uid:
stripped_uuids.add(uid)
else:
kept.append(entry)
# Reparent: walk up chain through stripped entries to find surviving ancestor
for entry in kept:
if "_raw" in entry:
# Second pass: keep non-stripped entries, reparenting where needed.
# Preserve original line when no reparenting is required.
reparented: set[str] = set()
for _line, entry in parsed:
if not isinstance(entry, dict):
continue
parent = entry.get("parentUuid", "")
original_parent = parent
@@ -100,63 +97,32 @@ def strip_progress_entries(content: str) -> str:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
uid = entry.get("uuid", "")
if uid:
reparented.add(uid)
result_lines: list[str] = []
for entry in kept:
if "_raw" in entry:
result_lines.append(entry["_raw"])
else:
for line, entry in parsed:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES:
continue
uid = entry.get("uuid", "")
if uid in reparented:
# Re-serialize only entries whose parentUuid was changed.
result_lines.append(json.dumps(entry, separators=(",", ":")))
else:
result_lines.append(line)
return "\n".join(result_lines) + "\n"
# ---------------------------------------------------------------------------
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
# Local file I/O (write temp file for --resume)
# ---------------------------------------------------------------------------
def read_transcript_file(transcript_path: str) -> str | None:
"""Read a JSONL transcript file from disk.
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
or only contains metadata (≤2 lines with no conversation messages).
"""
if not transcript_path or not os.path.isfile(transcript_path):
logger.debug(f"[Transcript] File not found: {transcript_path}")
return None
try:
with open(transcript_path) as f:
content = f.read()
if not content.strip():
logger.debug("[Transcript] File is empty: %s", transcript_path)
return None
lines = content.strip().split("\n")
# Validate that the transcript has real conversation content
# (not just metadata like queue-operation entries).
if not validate_transcript(content):
logger.debug(
"[Transcript] No conversation content (%d lines) in %s",
len(lines),
transcript_path,
)
return None
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
)
return content
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
return None
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
"""Sanitize an ID for safe use in file paths.
@@ -171,14 +137,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
@@ -188,7 +146,8 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""
import shutil
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
@@ -248,32 +207,29 @@ def write_transcript_to_tempfile(
def validate_transcript(content: str | None) -> bool:
"""Check that a transcript has actual conversation messages.
A valid transcript for resume needs at least one user message and one
assistant message (not just queue-operation / file-history-snapshot
metadata).
A valid transcript needs at least one assistant message (not just
queue-operation / file-history-snapshot metadata). We do NOT require
a ``type: "user"`` entry because with ``--resume`` the user's message
is passed as a CLI query parameter and does not appear in the
transcript file.
"""
if not content or not content.strip():
return False
lines = content.strip().split("\n")
if len(lines) < 2:
return False
has_user = False
has_assistant = False
for line in lines:
try:
entry = json.loads(line)
msg_type = entry.get("type")
if msg_type == "user":
has_user = True
elif msg_type == "assistant":
has_assistant = True
except json.JSONDecodeError:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
return False
if entry.get("type") == "assistant":
has_assistant = True
return has_user and has_assistant
return has_assistant
# ---------------------------------------------------------------------------
@@ -328,26 +284,41 @@ async def upload_transcript(
session_id: str,
content: str,
message_count: int = 0,
log_prefix: str = "[Transcript]",
) -> None:
"""Strip progress entries and upload transcript to bucket storage.
"""Strip progress entries and upload complete transcript.
The transcript represents the FULL active context (atomic).
Each upload REPLACES the previous transcript entirely.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen. We always overwrite — with ``--resume``
the CLI may compact old tool results, so neither byte size nor line count
is a reliable proxy for "newer".
the same session cannot happen.
Args:
message_count: ``len(session.messages)`` at upload time — used by
the next turn to detect staleness and compress only the gap.
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
"""
from backend.util.workspace_storage import get_workspace_storage
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types: list[str] = []
for line in stripped.strip().split("\n"):
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
entry_types.append(entry.get("type", "?"))
logger.warning(
f"[Transcript] Skipping upload — stripped content not valid "
f"for session {session_id}"
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
log_prefix,
entry_types,
len(stripped),
len(content),
)
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
return
storage = await get_workspace_storage()
@@ -373,17 +344,18 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.info(
f"[Transcript] Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count}) "
f"for session {session_id}"
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
)
async def download_transcript(
user_id: str, session_id: str
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
@@ -399,10 +371,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
logger.debug(f"{log_prefix} No transcript in storage")
return None
except Exception as e:
logger.warning(f"[Transcript] Failed to download transcript: {e}")
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -419,16 +391,13 @@ async def download_transcript(
meta_path = f"local://{mwid}/{mfid}/{mfname}"
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"))
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, json.JSONDecodeError, Exception):
except (FileNotFoundError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
logger.info(
f"[Transcript] Downloaded {len(content)}B "
f"(msg_count={message_count}) for session {session_id}"
)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
return TranscriptDownload(
content=content,
message_count=message_count,

View File

@@ -0,0 +1,188 @@
"""Build complete JSONL transcript from SDK messages.
The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.
Flow:
Turn 1: Upload [msg1, msg2]
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
The transcript is never incremental - always the complete atomic state.
"""
import logging
from typing import Any
from uuid import uuid4
from pydantic import BaseModel
from backend.util import json
from .transcript import STRIPPABLE_TYPES
logger = logging.getLogger(__name__)
class TranscriptEntry(BaseModel):
"""Single transcript entry (user or assistant turn)."""
type: str
uuid: str
parentUuid: str | None
message: dict[str, Any]
class TranscriptBuilder:
"""Build complete JSONL transcript from SDK messages.
This builder maintains the FULL conversation state, not incremental changes.
The output is always the complete active context.
"""
def __init__(self) -> None:
self._entries: list[TranscriptEntry] = []
self._last_uuid: str | None = None
def _last_is_assistant(self) -> bool:
return bool(self._entries) and self._entries[-1].type == "assistant"
def _last_message_id(self) -> str:
"""Return the message.id of the last entry, or '' if none."""
if self._entries:
return self._entries[-1].message.get("id", "")
return ""
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
This loads the FULL previous context. As new messages come in,
we append to this state. The final output is the complete context
(previous + new), not just the delta.
"""
if not content or not content.strip():
return
lines = content.strip().split("\n")
for line_num, line in enumerate(lines, 1):
if not line.strip():
continue
data = json.loads(line, fallback=None)
if data is None:
logger.warning(
"%s Failed to parse transcript line %d/%d",
log_prefix,
line_num,
len(lines),
)
continue
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES:
continue
entry = TranscriptEntry(
type=data["type"],
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
)
self._entries.append(entry)
self._last_uuid = entry.uuid
logger.info(
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
log_prefix,
len(self._entries),
self._last_uuid[:12] if self._last_uuid else None,
)
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
"""Append a user entry."""
msg_uuid = uuid or str(uuid4())
self._entries.append(
TranscriptEntry(
type="user",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={"role": "user", "content": content},
)
)
self._last_uuid = msg_uuid
def append_tool_result(self, tool_use_id: str, content: str) -> None:
"""Append a tool result as a user entry (one per tool call)."""
self.append_user(
content=[
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
]
)
def append_assistant(
self,
content_blocks: list[dict],
model: str = "",
stop_reason: str | None = None,
) -> None:
"""Append an assistant entry.
Consecutive assistant entries automatically share the same message ID
so the CLI can merge them (thinking → text → tool_use) into a single
API message on ``--resume``. A new ID is assigned whenever an
assistant entry follows a non-assistant entry (user message or tool
result), because that marks the start of a new API response.
"""
message_id = (
self._last_message_id()
if self._last_is_assistant()
else f"msg_sdk_{uuid4().hex[:24]}"
)
msg_uuid = str(uuid4())
self._entries.append(
TranscriptEntry(
type="assistant",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={
"role": "assistant",
"model": model,
"id": message_id,
"type": "message",
"content": content_blocks,
"stop_reason": stop_reason,
"stop_sequence": None,
},
)
)
self._last_uuid = msg_uuid
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
Consecutive assistant entries are kept separate to match the
native CLI format — the SDK merges them internally on resume.
Returns the FULL conversation state (all entries), not incremental.
This output REPLACES any previous transcript.
"""
if not self._entries:
return ""
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""
return len(self._entries)
@property
def is_empty(self) -> bool:
"""Whether this builder has any entries."""
return len(self._entries) == 0

View File

@@ -1,11 +1,11 @@
"""Unit tests for JSONL transcript management utilities."""
import json
import os
from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
read_transcript_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
@@ -38,49 +38,6 @@ PROGRESS_ENTRY = {
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
# --- read_transcript_file ---
class TestReadTranscriptFile:
def test_returns_content_for_valid_file(self, tmp_path):
path = tmp_path / "session.jsonl"
path.write_text(VALID_TRANSCRIPT)
result = read_transcript_file(str(path))
assert result is not None
assert "user" in result
def test_returns_none_for_missing_file(self):
assert read_transcript_file("/nonexistent/path.jsonl") is None
def test_returns_none_for_empty_path(self):
assert read_transcript_file("") is None
def test_returns_none_for_empty_file(self, tmp_path):
path = tmp_path / "empty.jsonl"
path.write_text("")
assert read_transcript_file(str(path)) is None
def test_returns_none_for_metadata_only(self, tmp_path):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
path = tmp_path / "meta.jsonl"
path.write_text(content)
assert read_transcript_file(str(path)) is None
def test_returns_none_for_invalid_json(self, tmp_path):
path = tmp_path / "bad.jsonl"
path.write_text("not json\n{}\n{}\n")
assert read_transcript_file(str(path)) is None
def test_no_size_limit(self, tmp_path):
"""Large files are accepted — bucket storage has no size limit."""
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
path = tmp_path / "big.jsonl"
path.write_text(content)
result = read_transcript_file(str(path))
assert result is not None
# --- write_transcript_to_tempfile ---
@@ -155,12 +112,56 @@ class TestValidateTranscript:
assert validate_transcript(content) is False
def test_assistant_only_no_user(self):
"""With --resume the user message is a CLI query param, not a transcript entry.
A transcript with only assistant entries is valid."""
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
assert validate_transcript(content) is False
assert validate_transcript(content) is True
def test_resume_transcript_without_user_entry(self):
"""Simulates a real --resume stop hook transcript: the CLI session file
has summary + assistant entries but no user entry."""
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
asst1 = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "content": "Hello!"},
}
asst2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "a1",
"message": {"role": "assistant", "content": "Sure, let me help."},
}
content = _make_jsonl(summary, asst1, asst2)
assert validate_transcript(content) is True
def test_single_assistant_entry(self):
"""A transcript with just one assistant line is valid — the CLI may
produce short transcripts for simple responses with no tool use."""
content = json.dumps(ASST_MSG) + "\n"
assert validate_transcript(content) is True
def test_invalid_json_returns_false(self):
assert validate_transcript("not json\n{}\n{}\n") is False
def test_malformed_json_after_valid_assistant_returns_false(self):
"""Validation must scan all lines - malformed JSON anywhere should fail."""
valid_asst = json.dumps(ASST_MSG)
malformed = "not valid json"
content = valid_asst + "\n" + malformed + "\n"
assert validate_transcript(content) is False
def test_blank_lines_are_skipped(self):
"""Transcripts with blank lines should be valid if they contain assistant entries."""
content = (
json.dumps(USER_MSG)
+ "\n\n" # blank line
+ json.dumps(ASST_MSG)
+ "\n"
+ "\n" # another blank line
)
assert validate_transcript(content) is True
# --- strip_progress_entries ---
@@ -253,3 +254,31 @@ class TestStripProgressEntries:
assert "queue-operation" not in result_types
assert "user" in result_types
assert "assistant" in result_types
def test_preserves_original_line_formatting(self):
"""Non-reparented entries keep their original JSON formatting."""
# orjson produces compact JSON - test that we preserve the exact input
# when no reparenting is needed (no re-serialization)
original_line = json.dumps(USER_MSG)
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
result = strip_progress_entries(content)
result_lines = result.strip().split("\n")
# Original line should be byte-identical (not re-serialized)
assert result_lines[0] == original_line
def test_reparented_entries_are_reserialized(self):
"""Entries whose parentUuid changes must be re-serialized."""
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "p1",
"message": {"role": "assistant", "content": "done"},
}
content = _make_jsonl(USER_MSG, progress, asst)
result = strip_progress_entries(content)
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented

View File

@@ -18,7 +18,7 @@ from langfuse.openai import (
from backend.data.db_accessors import understanding_db
from backend.data.understanding import format_understanding_for_prompt
from backend.util.exceptions import NotFoundError
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ChatConfig
@@ -34,8 +34,9 @@ client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
langfuse = get_client()
# Default system prompt used when Langfuse is not configured
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
# Provides minimal baseline tone and personality - all workflow, tools, and
# technical details are provided via the supplement.
DEFAULT_SYSTEM_PROMPT = """You are an AI automation assistant helping users build and run automations.
Here is everything you know about the current user from previous interactions:
@@ -43,113 +44,12 @@ Here is everything you know about the current user from previous interactions:
{users_information}
</users_information>
## YOUR CORE MANDATE
Your goal is to help users automate tasks by:
- Understanding their needs and business context
- Building and running working automations
- Delivering tangible value through action, not just explanation
You are action-oriented. Your success is measured by:
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
- **Time Saved**: Focus on tangible efficiency gains
- **Quality Output**: Deliver results that meet or exceed expectations
## YOUR WORKFLOW
Adapt flexibly to the conversation context. Not every interaction requires all stages:
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
4. **Discover or Create Agents**:
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
- Search the marketplace with `find_agent` for pre-built automations
- Find reusable components with `find_block`
- **For live integrations** (read a GitHub repo, query a database, post to Slack, etc.) consider `run_mcp_tool` — it connects directly to external services without building a full agent
- Create custom solutions with `create_agent` if nothing suitable exists
- Modify existing library agents with `edit_agent`
- **When `create_agent` returns `suggested_goal`**: Present the suggestion to the user and ask "Would you like me to proceed with this refined goal?" If they accept, call `create_agent` again with the suggested goal.
- **When `create_agent` returns `clarifying_questions`**: After the user answers, call `create_agent` again with the original description AND the answers in the `context` parameter.
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
6. **Show Results**: Display outputs using `agent_output`.
## AVAILABLE TOOLS
**Understanding & Discovery:**
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
- `search_docs`: Search platform documentation for specific technical information
- `get_doc_page`: Retrieve full text of a specific documentation page
**Agent Discovery:**
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
- `find_agent`: Search the marketplace for pre-built automations
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
**Agent Creation & Editing:**
- `create_agent`: Create a new automation agent
- `edit_agent`: Modify an agent in the user's library
**Execution & Output:**
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
- `run_block`: Test or run a specific block independently
- `agent_output`: View results from previous agent runs
**MCP (Model Context Protocol) Servers:**
- `run_mcp_tool`: Connect to any MCP server to discover and run its tools
**Two-step flow:**
1. `run_mcp_tool(server_url)` → returns a list of available tools. Each tool has `name`, `description`, and `input_schema` (JSON Schema). Read `input_schema.properties` to understand what arguments are needed.
2. `run_mcp_tool(server_url, tool_name, tool_arguments)` → executes the tool. Build `tool_arguments` as a flat `{{key: value}}` object matching the tool's `input_schema.properties`.
**Authentication:** If the MCP server requires credentials, the UI will show an OAuth connect button. Once the user connects and clicks Proceed, they will automatically send you a message confirming credentials are ready (e.g. "I've connected the MCP server credentials. Please retry run_mcp_tool..."). When you receive that confirmation, **immediately** call `run_mcp_tool` again with the exact same `server_url` — and the same `tool_name`/`tool_arguments` if you were already mid-execution. Do not ask the user what to do next; just retry.
**Finding server URLs (fastest → slowest):**
1. **Known hosted servers** — use directly, no lookup:
- Notion: `https://mcp.notion.com/mcp`
- Linear: `https://mcp.linear.app/mcp`
- Stripe: `https://mcp.stripe.com`
- Intercom: `https://mcp.intercom.com/mcp`
- Cloudflare: `https://mcp.cloudflare.com/mcp`
- Atlassian (Jira/Confluence): `https://mcp.atlassian.com/mcp`
2. **`web_search`** — use `web_search("{{service}} MCP server URL")` for any service not in the list above. This is the fastest way to find unlisted servers.
3. **Registry API** — `web_fetch("https://registry.modelcontextprotocol.io/v0.1/servers?search={{query}}&limit=10")` to browse what's available. Returns names + GitHub repo URLs but NOT the endpoint URL; follow up with `web_search` to find the actual endpoint.
- **Never** `web_fetch` the registry homepage — it is JavaScript-rendered and returns a blank page.
**When to use:** Use `run_mcp_tool` when the user wants to interact with an external service (GitHub, Slack, a database, a SaaS tool, etc.) via its MCP integration. Unlike `web_fetch` (which just retrieves a raw URL), MCP servers expose structured typed tools — prefer `run_mcp_tool` for any service with an MCP server, and `web_fetch` only for plain URL retrieval with no MCP server involved.
**CRITICAL**: `run_mcp_tool` is **always available** in your tool list. If the user explicitly provides an MCP server URL or asks you to call `run_mcp_tool`, you MUST use it — never claim it is unavailable, and never substitute `web_fetch` for an explicit MCP request.
## BEHAVIORAL GUIDELINES
**Be Concise:**
- Target 2-5 short lines maximum
- Make every word count—no repetition or filler
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
- Avoid jargon (blocks, slugs, cron) unless the user asks
**Be Proactive:**
- Suggest next steps before being asked
- Anticipate needs based on conversation context and user information
- Look for opportunities to expand scope when relevant
- Reveal capabilities through action, not explanation
**Use Tools Effectively:**
- Select the right tool for each task
- **Always check `find_library_agent` before searching the marketplace**
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches
- **For MCP integrations**: Known URL (see list) or `web_search("{{service}} MCP server URL")` → `run_mcp_tool(server_url)` → `run_mcp_tool(server_url, tool_name, tool_arguments)`. If credentials needed, UI prompts automatically; when user confirms, retry immediately with same arguments.
**Handle Feedback Loops:**
- When a tool returns a suggested alternative (like a refined goal), present it clearly and ask the user for confirmation before proceeding
- When clarifying questions are answered, immediately re-call the tool with the accumulated context
- Don't ask redundant questions if the user has already provided context in the conversation
## CRITICAL REMINDER
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
Be concise, proactive, and action-oriented. Bias toward showing working solutions over lengthy explanations."""
# ---------------------------------------------------------------------------
@@ -298,6 +198,12 @@ async def assign_user_to_session(
session = await get_chat_session(session_id, None)
if not session:
raise NotFoundError(f"Session {session_id} not found")
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"[SECURITY] Attempt to claim session {session_id} by user {user_id}, "
f"but it already belongs to user {session.user_id}"
)
raise NotAuthorizedError(f"Not authorized to claim session {session_id}")
session.user_id = user_id
session = await upsert_chat_session(session)
return session

View File

@@ -20,6 +20,14 @@ from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .get_doc_page import GetDocPageTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
@@ -47,6 +55,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
"find_library_agent": FindLibraryAgentTool(),
# Folder management tools
"create_folder": CreateFolderTool(),
"list_folders": ListFoldersTool(),
"update_folder": UpdateFolderTool(),
"move_folder": MoveFolderTool(),
"delete_folder": DeleteFolderTool(),
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),

View File

@@ -151,8 +151,8 @@ async def setup_test_data(server):
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="Test Agent",
description="A simple test agent",
@@ -161,10 +161,10 @@ async def setup_test_data(server):
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
# 4. Approve the store listing version
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval",
@@ -321,8 +321,8 @@ async def setup_llm_test_data(server):
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="LLM Test Agent",
description="An agent with LLM capabilities",
@@ -330,9 +330,9 @@ async def setup_llm_test_data(server):
categories=["testing", "ai"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for LLM agent",
@@ -476,8 +476,8 @@ async def setup_firecrawl_test_data(server):
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="Firecrawl Test Agent",
description="An agent with Firecrawl integration (no credentials)",
@@ -485,9 +485,9 @@ async def setup_firecrawl_test_data(server):
categories=["testing", "scraping"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for Firecrawl agent",

View File

@@ -695,7 +695,10 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
agent_json: dict[str, Any],
user_id: str,
is_update: bool = False,
folder_id: str | None = None,
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
@@ -703,6 +706,7 @@ async def save_agent_to_library(
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
folder_id: Optional folder ID to place the agent in
Returns:
Tuple of (created Graph, LibraryAgent)
@@ -711,7 +715,7 @@ async def save_agent_to_library(
db = library_db()
if is_update:
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
def graph_to_json(graph: Graph) -> dict[str, Any]:

View File

@@ -39,9 +39,13 @@ class CreateAgentTool(BaseTool):
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
"(2) Call create_agent with description and library_agent_ids. "
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
"then call again with the suggested goal if accepted. "
"(4) If response contains clarifying_questions: Present to user, collect answers, "
"then call again with original description AND answers in the context parameter. "
"\n\nThis feedback loop ensures the generated agent matches user intent."
)
@property
@@ -84,6 +88,14 @@ class CreateAgentTool(BaseTool):
),
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
},
},
"required": ["description"],
}
@@ -105,6 +117,7 @@ class CreateAgentTool(BaseTool):
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
logger.info(
@@ -336,7 +349,7 @@ class CreateAgentTool(BaseTool):
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id
agent_json, user_id, folder_id=folder_id
)
logger.info(

View File

@@ -3,9 +3,9 @@
import logging
from typing import Any
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from backend.util.exceptions import NotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
@@ -80,6 +80,14 @@ class CustomizeAgentTool(BaseTool):
),
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
},
},
"required": ["agent_id", "modifications"],
}
@@ -102,6 +110,7 @@ class CustomizeAgentTool(BaseTool):
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
if not agent_id:
@@ -140,7 +149,7 @@ class CustomizeAgentTool(BaseTool):
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
except NotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
@@ -310,7 +319,7 @@ class CustomizeAgentTool(BaseTool):
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
customized_agent, user_id, is_update=False, folder_id=folder_id
)
return AgentSavedResponse(

View File

@@ -0,0 +1,573 @@
"""Folder management tools for the copilot."""
from typing import Any
from backend.api.features.library import model as library_model
from backend.api.features.library.db import collect_tree_ids
from backend.copilot.model import ChatSession
from backend.data.db_accessors import library_db
from .base import BaseTool
from .models import (
AgentsMovedToFolderResponse,
ErrorResponse,
FolderAgentSummary,
FolderCreatedResponse,
FolderDeletedResponse,
FolderInfo,
FolderListResponse,
FolderMovedResponse,
FolderTreeInfo,
FolderUpdatedResponse,
ToolResponseBase,
)
def _folder_to_info(
folder: library_model.LibraryFolder,
agents: list[FolderAgentSummary] | None = None,
) -> FolderInfo:
"""Convert a LibraryFolder DB model to a FolderInfo response model."""
return FolderInfo(
id=folder.id,
name=folder.name,
parent_id=folder.parent_id,
icon=folder.icon,
color=folder.color,
agent_count=folder.agent_count,
subfolder_count=folder.subfolder_count,
agents=agents,
)
def _tree_to_info(
tree: library_model.LibraryFolderTree,
agents_map: dict[str, list[FolderAgentSummary]] | None = None,
) -> FolderTreeInfo:
"""Recursively convert a LibraryFolderTree to a FolderTreeInfo response."""
return FolderTreeInfo(
id=tree.id,
name=tree.name,
parent_id=tree.parent_id,
icon=tree.icon,
color=tree.color,
agent_count=tree.agent_count,
subfolder_count=tree.subfolder_count,
children=[_tree_to_info(child, agents_map) for child in tree.children],
agents=agents_map.get(tree.id) if agents_map else None,
)
def _to_agent_summaries(
raw: list[dict[str, str | None]],
) -> list[FolderAgentSummary]:
"""Convert raw agent dicts to typed FolderAgentSummary models."""
return [
FolderAgentSummary(
id=a["id"] or "",
name=a["name"] or "",
description=a["description"] or "",
)
for a in raw
]
def _to_agent_summaries_map(
raw: dict[str, list[dict[str, str | None]]],
) -> dict[str, list[FolderAgentSummary]]:
"""Convert a folder-id-keyed dict of raw agents to typed summaries."""
return {fid: _to_agent_summaries(agents) for fid, agents in raw.items()}
class CreateFolderTool(BaseTool):
"""Tool for creating a library folder."""
@property
def name(self) -> str:
return "create_folder"
@property
def description(self) -> str:
return (
"Create a new folder in the user's library to organize agents. "
"Optionally nest it inside an existing folder using parent_id."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name for the new folder (max 100 chars).",
},
"parent_id": {
"type": "string",
"description": (
"ID of the parent folder to nest inside. "
"Omit to create at root level."
),
},
"icon": {
"type": "string",
"description": "Optional icon identifier for the folder.",
},
"color": {
"type": "string",
"description": "Optional hex color code (#RRGGBB).",
},
},
"required": ["name"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Create a folder with the given name and optional parent/icon/color."""
assert user_id is not None # guaranteed by requires_auth
name = (kwargs.get("name") or "").strip()
parent_id = kwargs.get("parent_id")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not name:
return ErrorResponse(
message="Please provide a folder name.",
error="missing_name",
session_id=session_id,
)
try:
folder = await library_db().create_folder(
user_id=user_id,
name=name,
parent_id=parent_id,
icon=icon,
color=color,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to create folder: {e}",
error="create_folder_failed",
session_id=session_id,
)
return FolderCreatedResponse(
message=f"Folder '{folder.name}' created successfully!",
folder=_folder_to_info(folder),
session_id=session_id,
)
class ListFoldersTool(BaseTool):
"""Tool for listing library folders."""
@property
def name(self) -> str:
return "list_folders"
@property
def description(self) -> str:
return (
"List the user's library folders. "
"Omit parent_id to get the full folder tree. "
"Provide parent_id to list only direct children of that folder. "
"Set include_agents=true to also return the agents inside each folder "
"and root-level agents not in any folder. Always set include_agents=true "
"when the user asks about agents, wants to see what's in their folders, "
"or mentions agents alongside folders."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"parent_id": {
"type": "string",
"description": (
"List children of this folder. "
"Omit to get the full folder tree."
),
},
"include_agents": {
"type": "boolean",
"description": (
"Whether to include the list of agents inside each folder. "
"Defaults to false."
),
},
},
"required": [],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""List folders as a flat list (by parent) or full tree."""
assert user_id is not None # guaranteed by requires_auth
parent_id = kwargs.get("parent_id")
include_agents = kwargs.get("include_agents", False)
session_id = session.session_id if session else None
try:
if parent_id:
folders = await library_db().list_folders(
user_id=user_id, parent_id=parent_id
)
raw_map = (
await library_db().get_folder_agents_map(
user_id, [f.id for f in folders]
)
if include_agents
else None
)
agents_map = _to_agent_summaries_map(raw_map) if raw_map else None
return FolderListResponse(
message=f"Found {len(folders)} folder(s).",
folders=[
_folder_to_info(f, agents_map.get(f.id) if agents_map else None)
for f in folders
],
count=len(folders),
session_id=session_id,
)
else:
tree = await library_db().get_folder_tree(user_id=user_id)
all_ids = collect_tree_ids(tree)
agents_map = None
root_agents = None
if include_agents:
raw_map = await library_db().get_folder_agents_map(user_id, all_ids)
agents_map = _to_agent_summaries_map(raw_map)
root_agents = _to_agent_summaries(
await library_db().get_root_agent_summaries(user_id)
)
return FolderListResponse(
message=f"Found {len(all_ids)} folder(s) in your library.",
tree=[_tree_to_info(t, agents_map) for t in tree],
root_agents=root_agents,
count=len(all_ids),
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to list folders: {e}",
error="list_folders_failed",
session_id=session_id,
)
class UpdateFolderTool(BaseTool):
"""Tool for updating a folder's properties."""
@property
def name(self) -> str:
return "update_folder"
@property
def description(self) -> str:
return "Update a folder's name, icon, or color."
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to update.",
},
"name": {
"type": "string",
"description": "New name for the folder.",
},
"icon": {
"type": "string",
"description": "New icon identifier.",
},
"color": {
"type": "string",
"description": "New hex color code (#RRGGBB).",
},
},
"required": ["folder_id"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Update a folder's name, icon, or color."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (kwargs.get("folder_id") or "").strip()
name = kwargs.get("name")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not folder_id:
return ErrorResponse(
message="Please provide a folder_id.",
error="missing_folder_id",
session_id=session_id,
)
try:
folder = await library_db().update_folder(
folder_id=folder_id,
user_id=user_id,
name=name,
icon=icon,
color=color,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to update folder: {e}",
error="update_folder_failed",
session_id=session_id,
)
return FolderUpdatedResponse(
message=f"Folder updated to '{folder.name}'.",
folder=_folder_to_info(folder),
session_id=session_id,
)
class MoveFolderTool(BaseTool):
"""Tool for moving a folder to a new parent."""
@property
def name(self) -> str:
return "move_folder"
@property
def description(self) -> str:
return (
"Move a folder to a different parent folder. "
"Set target_parent_id to null to move to root level."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to move.",
},
"target_parent_id": {
"type": ["string", "null"],
"description": (
"ID of the new parent folder. "
"Use null to move to root level."
),
},
},
"required": ["folder_id"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move a folder to a new parent or to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (kwargs.get("folder_id") or "").strip()
target_parent_id = kwargs.get("target_parent_id")
session_id = session.session_id if session else None
if not folder_id:
return ErrorResponse(
message="Please provide a folder_id.",
error="missing_folder_id",
session_id=session_id,
)
try:
folder = await library_db().move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=target_parent_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to move folder: {e}",
error="move_folder_failed",
session_id=session_id,
)
dest = "a subfolder" if target_parent_id else "root level"
return FolderMovedResponse(
message=f"Folder '{folder.name}' moved to {dest}.",
folder=_folder_to_info(folder),
target_parent_id=target_parent_id,
session_id=session_id,
)
class DeleteFolderTool(BaseTool):
"""Tool for deleting a folder."""
@property
def name(self) -> str:
return "delete_folder"
@property
def description(self) -> str:
return (
"Delete a folder from the user's library. "
"Agents inside the folder are moved to root level (not deleted)."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to delete.",
},
},
"required": ["folder_id"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Soft-delete a folder; agents inside are moved to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (kwargs.get("folder_id") or "").strip()
session_id = session.session_id if session else None
if not folder_id:
return ErrorResponse(
message="Please provide a folder_id.",
error="missing_folder_id",
session_id=session_id,
)
try:
await library_db().delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to delete folder: {e}",
error="delete_folder_failed",
session_id=session_id,
)
return FolderDeletedResponse(
message="Folder deleted. Any agents inside were moved to root level.",
folder_id=folder_id,
session_id=session_id,
)
class MoveAgentsToFolderTool(BaseTool):
"""Tool for moving agents into a folder."""
@property
def name(self) -> str:
return "move_agents_to_folder"
@property
def description(self) -> str:
return (
"Move one or more agents to a folder. "
"Set folder_id to null to move agents to root level."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": "List of library agent IDs to move.",
},
"folder_id": {
"type": ["string", "null"],
"description": (
"Target folder ID. Use null to move to root level."
),
},
},
"required": ["agent_ids"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move one or more agents to a folder or to root level."""
assert user_id is not None # guaranteed by requires_auth
agent_ids = kwargs.get("agent_ids", [])
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
if not agent_ids:
return ErrorResponse(
message="Please provide at least one agent ID.",
error="missing_agent_ids",
session_id=session_id,
)
try:
moved = await library_db().bulk_move_agents_to_folder(
agent_ids=agent_ids,
folder_id=folder_id,
user_id=user_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to move agents: {e}",
error="move_agents_failed",
session_id=session_id,
)
moved_ids = [a.id for a in moved]
agent_names = [a.name for a in moved]
dest = "the folder" if folder_id else "root level"
names_str = (
", ".join(agent_names) if agent_names else f"{len(agent_ids)} agent(s)"
)
return AgentsMovedToFolderResponse(
message=f"Moved {names_str} to {dest}.",
agent_ids=moved_ids,
agent_names=agent_names,
folder_id=folder_id,
count=len(moved),
session_id=session_id,
)

View File

@@ -0,0 +1,455 @@
"""Tests for folder management copilot tools."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.library import model as library_model
from backend.copilot.tools.manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from backend.copilot.tools.models import (
AgentsMovedToFolderResponse,
ErrorResponse,
FolderCreatedResponse,
FolderDeletedResponse,
FolderListResponse,
FolderMovedResponse,
FolderUpdatedResponse,
)
from ._test_data import make_session
_TEST_USER_ID = "test-user-folders"
_NOW = datetime.now(UTC)
def _make_folder(
id: str = "folder-1",
name: str = "My Folder",
parent_id: str | None = None,
icon: str | None = None,
color: str | None = None,
agent_count: int = 0,
subfolder_count: int = 0,
) -> library_model.LibraryFolder:
return library_model.LibraryFolder(
id=id,
user_id=_TEST_USER_ID,
name=name,
icon=icon,
color=color,
parent_id=parent_id,
created_at=_NOW,
updated_at=_NOW,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
def _make_tree(
id: str = "folder-1",
name: str = "Root",
children: list[library_model.LibraryFolderTree] | None = None,
) -> library_model.LibraryFolderTree:
return library_model.LibraryFolderTree(
id=id,
user_id=_TEST_USER_ID,
name=name,
created_at=_NOW,
updated_at=_NOW,
children=children or [],
)
def _make_library_agent(id: str = "agent-1", name: str = "Test Agent"):
agent = MagicMock()
agent.id = id
agent.name = name
return agent
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
# ── CreateFolderTool ──
@pytest.fixture
def create_tool():
return CreateFolderTool()
@pytest.mark.asyncio
async def test_create_folder_missing_name(create_tool, session):
result = await create_tool._execute(user_id=_TEST_USER_ID, session=session, name="")
assert isinstance(result, ErrorResponse)
assert result.error == "missing_name"
@pytest.mark.asyncio
async def test_create_folder_none_name(create_tool, session):
result = await create_tool._execute(
user_id=_TEST_USER_ID, session=session, name=None
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_name"
@pytest.mark.asyncio
async def test_create_folder_success(create_tool, session):
folder = _make_folder(name="New Folder")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.create_folder = AsyncMock(return_value=folder)
result = await create_tool._execute(
user_id=_TEST_USER_ID, session=session, name="New Folder"
)
assert isinstance(result, FolderCreatedResponse)
assert result.folder.name == "New Folder"
assert "New Folder" in result.message
@pytest.mark.asyncio
async def test_create_folder_db_error(create_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.create_folder = AsyncMock(
side_effect=Exception("db down")
)
result = await create_tool._execute(
user_id=_TEST_USER_ID, session=session, name="Folder"
)
assert isinstance(result, ErrorResponse)
assert result.error == "create_folder_failed"
# ── ListFoldersTool ──
@pytest.fixture
def list_tool():
return ListFoldersTool()
@pytest.mark.asyncio
async def test_list_folders_by_parent(list_tool, session):
folders = [_make_folder(id="f1", name="A"), _make_folder(id="f2", name="B")]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.list_folders = AsyncMock(return_value=folders)
result = await list_tool._execute(
user_id=_TEST_USER_ID, session=session, parent_id="parent-1"
)
assert isinstance(result, FolderListResponse)
assert result.count == 2
assert len(result.folders) == 2
@pytest.mark.asyncio
async def test_list_folders_tree(list_tool, session):
tree = [
_make_tree(id="r1", name="Root", children=[_make_tree(id="c1", name="Child")])
]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, FolderListResponse)
assert result.count == 2 # root + child
assert result.tree is not None
assert len(result.tree) == 1
@pytest.mark.asyncio
async def test_list_folders_tree_with_agents_includes_root(list_tool, session):
tree = [_make_tree(id="r1", name="Root")]
raw_map = {"r1": [{"id": "a1", "name": "Foldered", "description": "In folder"}]}
root_raw = [{"id": "a2", "name": "Loose Agent", "description": "At root"}]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
mock_lib.return_value.get_folder_agents_map = AsyncMock(return_value=raw_map)
mock_lib.return_value.get_root_agent_summaries = AsyncMock(
return_value=root_raw
)
result = await list_tool._execute(
user_id=_TEST_USER_ID, session=session, include_agents=True
)
assert isinstance(result, FolderListResponse)
assert result.root_agents is not None
assert len(result.root_agents) == 1
assert result.root_agents[0].name == "Loose Agent"
assert result.tree is not None
assert result.tree[0].agents is not None
assert result.tree[0].agents[0].name == "Foldered"
mock_lib.return_value.get_root_agent_summaries.assert_awaited_once_with(
_TEST_USER_ID
)
@pytest.mark.asyncio
async def test_list_folders_tree_without_agents_no_root(list_tool, session):
tree = [_make_tree(id="r1", name="Root")]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
result = await list_tool._execute(
user_id=_TEST_USER_ID, session=session, include_agents=False
)
assert isinstance(result, FolderListResponse)
assert result.root_agents is None
@pytest.mark.asyncio
async def test_list_folders_db_error(list_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(
side_effect=Exception("timeout")
)
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error == "list_folders_failed"
# ── UpdateFolderTool ──
@pytest.fixture
def update_tool():
return UpdateFolderTool()
@pytest.mark.asyncio
async def test_update_folder_missing_id(update_tool, session):
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_update_folder_none_id(update_tool, session):
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=None
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_update_folder_success(update_tool, session):
folder = _make_folder(name="Renamed")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.update_folder = AsyncMock(return_value=folder)
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="Renamed"
)
assert isinstance(result, FolderUpdatedResponse)
assert result.folder.name == "Renamed"
@pytest.mark.asyncio
async def test_update_folder_db_error(update_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.update_folder = AsyncMock(
side_effect=Exception("not found")
)
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="X"
)
assert isinstance(result, ErrorResponse)
assert result.error == "update_folder_failed"
# ── MoveFolderTool ──
@pytest.fixture
def move_tool():
return MoveFolderTool()
@pytest.mark.asyncio
async def test_move_folder_missing_id(move_tool, session):
result = await move_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_move_folder_to_parent(move_tool, session):
folder = _make_folder(name="Moved")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
result = await move_tool._execute(
user_id=_TEST_USER_ID,
session=session,
folder_id="folder-1",
target_parent_id="parent-1",
)
assert isinstance(result, FolderMovedResponse)
assert "subfolder" in result.message
@pytest.mark.asyncio
async def test_move_folder_to_root(move_tool, session):
folder = _make_folder(name="Moved")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
result = await move_tool._execute(
user_id=_TEST_USER_ID,
session=session,
folder_id="folder-1",
target_parent_id=None,
)
assert isinstance(result, FolderMovedResponse)
assert "root level" in result.message
@pytest.mark.asyncio
async def test_move_folder_db_error(move_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.move_folder = AsyncMock(side_effect=Exception("circular"))
result = await move_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
)
assert isinstance(result, ErrorResponse)
assert result.error == "move_folder_failed"
# ── DeleteFolderTool ──
@pytest.fixture
def delete_tool():
return DeleteFolderTool()
@pytest.mark.asyncio
async def test_delete_folder_missing_id(delete_tool, session):
result = await delete_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_delete_folder_success(delete_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.delete_folder = AsyncMock(return_value=None)
result = await delete_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
)
assert isinstance(result, FolderDeletedResponse)
assert result.folder_id == "folder-1"
assert "root level" in result.message
@pytest.mark.asyncio
async def test_delete_folder_db_error(delete_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.delete_folder = AsyncMock(
side_effect=Exception("permission denied")
)
result = await delete_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
)
assert isinstance(result, ErrorResponse)
assert result.error == "delete_folder_failed"
# ── MoveAgentsToFolderTool ──
@pytest.fixture
def move_agents_tool():
return MoveAgentsToFolderTool()
@pytest.mark.asyncio
async def test_move_agents_missing_ids(move_agents_tool, session):
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID, session=session, agent_ids=[]
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_agent_ids"
@pytest.mark.asyncio
async def test_move_agents_success(move_agents_tool, session):
agents = [
_make_library_agent(id="a1", name="Agent Alpha"),
_make_library_agent(id="a2", name="Agent Beta"),
]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
return_value=agents
)
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_ids=["a1", "a2"],
folder_id="folder-1",
)
assert isinstance(result, AgentsMovedToFolderResponse)
assert result.count == 2
assert result.agent_names == ["Agent Alpha", "Agent Beta"]
assert "Agent Alpha" in result.message
assert "Agent Beta" in result.message
@pytest.mark.asyncio
async def test_move_agents_to_root(move_agents_tool, session):
agents = [_make_library_agent(id="a1", name="Agent One")]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
return_value=agents
)
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_ids=["a1"],
folder_id=None,
)
assert isinstance(result, AgentsMovedToFolderResponse)
assert "root level" in result.message
@pytest.mark.asyncio
async def test_move_agents_db_error(move_agents_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
side_effect=Exception("folder not found")
)
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_ids=["a1"],
folder_id="bad-folder",
)
assert isinstance(result, ErrorResponse)
assert result.error == "move_agents_failed"

View File

@@ -55,6 +55,13 @@ class ResponseType(str, Enum):
# MCP tool types
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Folder management types
FOLDER_CREATED = "folder_created"
FOLDER_LIST = "folder_list"
FOLDER_UPDATED = "folder_updated"
FOLDER_MOVED = "folder_moved"
FOLDER_DELETED = "folder_deleted"
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
# Base response model
@@ -539,3 +546,82 @@ class BrowserScreenshotResponse(ToolResponseBase):
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
file_id: str # Workspace file ID — use read_workspace_file to retrieve
filename: str
# Folder management models
class FolderAgentSummary(BaseModel):
"""Lightweight agent info for folder listings."""
id: str
name: str
description: str = ""
class FolderInfo(BaseModel):
"""Information about a folder."""
id: str
name: str
parent_id: str | None = None
icon: str | None = None
color: str | None = None
agent_count: int = 0
subfolder_count: int = 0
agents: list[FolderAgentSummary] | None = None
class FolderTreeInfo(FolderInfo):
"""Folder with nested children for tree display."""
children: list["FolderTreeInfo"] = []
class FolderCreatedResponse(ToolResponseBase):
"""Response when a folder is created."""
type: ResponseType = ResponseType.FOLDER_CREATED
folder: FolderInfo
class FolderListResponse(ToolResponseBase):
"""Response for listing folders."""
type: ResponseType = ResponseType.FOLDER_LIST
folders: list[FolderInfo] = Field(default_factory=list)
tree: list[FolderTreeInfo] | None = None
root_agents: list[FolderAgentSummary] | None = None
count: int = 0
class FolderUpdatedResponse(ToolResponseBase):
"""Response when a folder is updated."""
type: ResponseType = ResponseType.FOLDER_UPDATED
folder: FolderInfo
class FolderMovedResponse(ToolResponseBase):
"""Response when a folder is moved."""
type: ResponseType = ResponseType.FOLDER_MOVED
folder: FolderInfo
target_parent_id: str | None = None
class FolderDeletedResponse(ToolResponseBase):
"""Response when a folder is deleted."""
type: ResponseType = ResponseType.FOLDER_DELETED
folder_id: str
class AgentsMovedToFolderResponse(ToolResponseBase):
"""Response when agents are moved to a folder."""
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
agent_ids: list[str]
agent_names: list[str] = []
folder_id: str | None = None
count: int = 0

View File

@@ -53,11 +53,15 @@ class RunMCPToolTool(BaseTool):
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Call with just `server_url` to see available tools. "
"Then call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"If the server requires authentication, the user will be prompted to connect it. "
"Find MCP servers at https://registry.modelcontextprotocol.io/ — hundreds of integrations "
"including GitHub, Postgres, Slack, filesystem, and more."
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
"Once connected and user confirms, retry the same call immediately."
)
@property

View File

@@ -4,11 +4,20 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
from backend.api.features.library.db import (
add_store_agent_to_library,
bulk_move_agents_to_folder,
create_folder,
create_graph_in_library,
create_library_agent,
delete_folder,
get_folder_agents_map,
get_folder_tree,
get_library_agent,
get_library_agent_by_graph_id,
get_root_agent_summaries,
list_folders,
list_library_agents,
move_folder,
update_folder,
update_graph_in_library,
)
from backend.api.features.store.db import (
@@ -260,6 +269,16 @@ class DatabaseManager(AppService):
update_graph_in_library = _(update_graph_in_library)
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
create_folder = _(create_folder)
list_folders = _(list_folders)
get_folder_tree = _(get_folder_tree)
update_folder = _(update_folder)
move_folder = _(move_folder)
delete_folder = _(delete_folder)
bulk_move_agents_to_folder = _(bulk_move_agents_to_folder)
get_folder_agents_map = _(get_folder_agents_map)
get_root_agent_summaries = _(get_root_agent_summaries)
# ============ Onboarding ============ #
increment_onboarding_runs = _(increment_onboarding_runs)
@@ -305,6 +324,7 @@ class DatabaseManager(AppService):
delete_chat_session = _(chat_db.delete_chat_session)
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
update_chat_session_title = _(chat_db.update_chat_session_title)
class DatabaseManagerClient(AppServiceClient):
@@ -433,6 +453,17 @@ class DatabaseManagerAsyncClient(AppServiceClient):
update_graph_in_library = d.update_graph_in_library
validate_graph_execution_permissions = d.validate_graph_execution_permissions
# ============ Library Folders ============ #
create_folder = d.create_folder
list_folders = d.list_folders
get_folder_tree = d.get_folder_tree
update_folder = d.update_folder
move_folder = d.move_folder
delete_folder = d.delete_folder
bulk_move_agents_to_folder = d.bulk_move_agents_to_folder
get_folder_agents_map = d.get_folder_agents_map
get_root_agent_summaries = d.get_root_agent_summaries
# ============ Onboarding ============ #
increment_onboarding_runs = d.increment_onboarding_runs
@@ -475,3 +506,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
delete_chat_session = d.delete_chat_session
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content
update_chat_session_title = d.update_chat_session_title

View File

@@ -344,7 +344,7 @@ class GraphExecution(GraphExecutionMeta):
),
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
"payload": exec.input_data.get("payload")
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
from uuid import UUID
import fastapi.exceptions
import prisma
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -250,8 +251,8 @@ async def test_clean_graph(server: SpinTestServer):
"_test_id": "node_with_secrets",
"input": "normal_value",
"control_test_input": "should be preserved",
"api_key": "secret_api_key_123", # Should be filtered
"password": "secret_password_456", # Should be filtered
"api_key": "secret_api_key_123", # Should be filtered # pragma: allowlist secret # noqa
"password": "secret_password_456", # Should be filtered # pragma: allowlist secret # noqa
"token": "secret_token_789", # Should be filtered
"credentials": { # Should be filtered
"id": "fake-github-credentials-id",
@@ -354,9 +355,24 @@ async def test_access_store_listing_graph(server: SpinTestServer):
create_graph, DEFAULT_USER_ID
)
# Ensure the default user has a Profile (required for store submissions)
existing_profile = await prisma.models.Profile.prisma().find_first(
where={"userId": DEFAULT_USER_ID}
)
if not existing_profile:
await prisma.models.Profile.prisma().create(
data=prisma.types.ProfileCreateInput(
userId=DEFAULT_USER_ID,
name="Default User",
username=f"default-user-{DEFAULT_USER_ID[:8]}",
description="Default test user profile",
links=[],
)
)
store_submission_request = store.StoreSubmissionRequest(
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=created_graph.id,
name="Test name",
sub_heading="Test sub heading",
@@ -385,8 +401,8 @@ async def test_access_store_listing_graph(server: SpinTestServer):
assert False, "Failed to create store listing"
slv_id = (
store_listing.store_listing_version_id
if store_listing.store_listing_version_id is not None
store_listing.listing_version_id
if store_listing.listing_version_id is not None
else None
)

View File

@@ -184,17 +184,17 @@ async def find_webhook_by_credentials_and_props(
credentials_id: str,
webhook_type: str,
resource: str,
events: Optional[list[str]],
events: list[str] | None = None,
) -> Webhook | None:
webhook = await IntegrationWebhook.prisma().find_first(
where={
"userId": user_id,
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
**({"events": {"has_every": events}} if events else {}),
},
)
where: IntegrationWebhookWhereInput = {
"userId": user_id,
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
}
if events is not None:
where["events"] = {"has_every": events}
webhook = await IntegrationWebhook.prisma().find_first(where=where)
return Webhook.from_db(webhook) if webhook else None

View File

@@ -0,0 +1,601 @@
import asyncio
import csv
import io
import logging
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, Optional
from uuid import uuid4
import prisma.enums
import prisma.models
import prisma.types
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
from backend.data.db import transaction
from backend.data.model import User
from backend.data.tally import get_business_understanding_input_from_tally
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
)
from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
_tally_seed_tasks: set[asyncio.Task] = set()
_email_adapter = TypeAdapter(EmailStr)
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
MAX_BULK_INVITE_ROWS = 500
class InvitedUserRecord(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
created_at=invited_user.createdAt,
updated_at=invited_user.updatedAt,
)
class BulkInvitedUserRowResult(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserRecord] = None
class BulkInvitedUsersResult(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResult]
@dataclass(frozen=True)
class _ParsedInviteRow:
row_number: int
email: str
name: Optional[str]
def normalize_email(email: str) -> str:
return email.strip().lower()
def _normalize_name(name: Optional[str]) -> Optional[str]:
if name is None:
return None
normalized = name.strip()
return normalized or None
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
if preferred_name:
return preferred_name
local_part = email.split("@", 1)[0].strip()
return local_part or "user"
def _sanitize_username_base(email: str) -> str:
local_part = email.split("@", 1)[0].lower()
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
sanitized = sanitized.strip("-")
return sanitized[:40] or "user"
async def _generate_unique_profile_username(email: str, tx) -> str:
base = _sanitize_username_base(email)
for attempt in range(10):
candidate = base if attempt == 0 else f"{base}-{uuid4().hex[:6]}"
existing = await prisma.models.Profile.prisma(tx).find_unique(
where={"username": candidate}
)
if existing is None:
return candidate
raise RuntimeError(f"Unable to generate unique username for {email}")
async def _ensure_default_profile(
user_id: str,
email: str,
preferred_name: Optional[str],
tx,
) -> None:
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
where={"userId": user_id}
)
if existing_profile is not None:
return
username = await _generate_unique_profile_username(email, tx)
await prisma.models.Profile.prisma(tx).create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name=_default_profile_name(email, preferred_name),
username=username,
description="I'm new here",
links=[],
avatarUrl="",
)
)
async def _ensure_default_onboarding(user_id: str, tx) -> None:
await prisma.models.UserOnboarding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def _apply_tally_understanding(
user_id: str,
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
if not isinstance(invited_user.tallyUnderstanding, dict):
return
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
payload = merge_business_understanding_data({}, input_data)
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(payload)},
"update": {"data": SafeJson(payload)},
},
)
async def list_invited_users() -> list[InvitedUserRecord]:
invited_users = await prisma.models.InvitedUser.prisma().find_many(
order={"createdAt": "desc"}
)
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:
normalized_email = normalize_email(email)
normalized_name = _normalize_name(name)
existing_user = await prisma.models.User.prisma().find_unique(
where={"email": normalized_email}
)
if existing_user is not None:
raise PreconditionFailed("An active user with this email already exists")
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if existing_invited_user is not None:
raise PreconditionFailed("An invited user with this email already exists")
invited_user = await prisma.models.InvitedUser.prisma().create(
data={
"email": normalized_email,
"name": normalized_name,
"status": prisma.enums.InvitedUserStatus.INVITED,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
}
)
schedule_invited_user_tally_precompute(invited_user.id)
return InvitedUserRecord.from_db(invited_user)
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
raise PreconditionFailed("Claimed invited users cannot be revoked")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return InvitedUserRecord.from_db(invited_user)
revoked_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
)
if revoked_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
return InvitedUserRecord.from_db(revoked_user)
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
refreshed_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": None,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
"tallyComputedAt": None,
"tallyError": None,
},
)
if refreshed_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
schedule_invited_user_tally_precompute(invited_user_id)
return InvitedUserRecord.from_db(refreshed_user)
def _decode_bulk_invite_file(content: bytes) -> str:
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
raise ValueError("Invite file exceeds the maximum size of 1 MB")
try:
return content.decode("utf-8-sig")
except UnicodeDecodeError as exc:
raise ValueError("Invite file must be UTF-8 encoded") from exc
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
indexed_rows: list[tuple[int, list[str]]] = []
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
normalized_row = [cell.strip() for cell in row]
if any(normalized_row):
indexed_rows.append((row_number, normalized_row))
if not indexed_rows:
return []
header = [cell.lower() for cell in indexed_rows[0][1]]
has_header = "email" in header
email_index = header.index("email") if has_header else 0
name_index = header.index("name") if has_header and "name" in header else 1
data_rows = indexed_rows[1:] if has_header else indexed_rows
parsed_rows: list[_ParsedInviteRow] = []
for row_number, row in data_rows:
email = row[email_index].strip() if len(row) > email_index else ""
name = row[name_index].strip() if len(row) > name_index else ""
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=email,
name=name or None,
)
)
return parsed_rows
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
parsed_rows: list[_ParsedInviteRow] = []
for row_number, raw_line in enumerate(text.splitlines(), start=1):
line = raw_line.strip()
if not line or line.startswith("#"):
continue
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=line,
name=None,
)
)
return parsed_rows
def _parse_bulk_invite_file(
filename: Optional[str],
content: bytes,
) -> list[_ParsedInviteRow]:
text = _decode_bulk_invite_file(content)
file_name = filename.lower() if filename else ""
parsed_rows = (
_parse_bulk_invite_csv(text)
if file_name.endswith(".csv")
else _parse_bulk_invite_text(text)
)
if not parsed_rows:
raise ValueError("Invite file did not contain any emails")
if len(parsed_rows) > MAX_BULK_INVITE_ROWS:
raise ValueError(
f"Invite file contains too many rows. Maximum supported rows: {MAX_BULK_INVITE_ROWS}"
)
return parsed_rows
async def bulk_create_invited_users_from_file(
filename: Optional[str],
content: bytes,
) -> BulkInvitedUsersResult:
parsed_rows = _parse_bulk_invite_file(filename, content)
created_count = 0
skipped_count = 0
error_count = 0
results: list[BulkInvitedUserRowResult] = []
seen_emails: set[str] = set()
for row in parsed_rows:
row_name = _normalize_name(row.name)
try:
validated_email = _email_adapter.validate_python(row.email)
except ValidationError:
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=row.email or None,
name=row_name,
status="ERROR",
message="Invalid email address",
)
)
continue
normalized_email = normalize_email(str(validated_email))
if normalized_email in seen_emails:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message="Duplicate email in upload file",
)
)
continue
seen_emails.add(normalized_email)
try:
invited_user = await create_invited_user(normalized_email, row_name)
except PreconditionFailed as exc:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message=str(exc),
)
)
except Exception:
logger.exception(
"Failed to create bulk invite for %s from row %s",
normalized_email,
row.row_number,
)
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="ERROR",
message="Unexpected error creating invite",
)
)
else:
created_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="CREATED",
message="Invite created",
invited_user=invited_user,
)
)
return BulkInvitedUsersResult(
created_count=created_count,
skipped_count=skipped_count,
error_count=error_count,
results=results,
)
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
return
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
"tallyError": None,
},
)
try:
input_data = await get_business_understanding_input_from_tally(
invited_user.email,
require_api_key=True,
)
payload = (
SafeJson(input_data.model_dump(exclude_none=True))
if input_data is not None
else None
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": payload,
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
"tallyComputedAt": datetime.now(timezone.utc),
"tallyError": None,
},
)
except Exception as exc:
logger.exception(
"Failed to compute Tally understanding for invited user %s",
invited_user_id,
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
"tallyError": str(exc),
},
)
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
_tally_seed_tasks.add(task)
task.add_done_callback(_tally_seed_tasks.discard)
async def get_or_activate_user(user_data: dict) -> User:
auth_user_id = user_data.get("sub")
if not auth_user_id:
raise NotAuthorizedError("User ID not found in token")
auth_email = user_data.get("email")
if not auth_email:
raise NotAuthorizedError("Email not found in token")
normalized_email = normalize_email(auth_email)
user_metadata = user_data.get("user_metadata")
metadata_name = (
user_metadata.get("name") if isinstance(user_metadata, dict) else None
)
existing_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if existing_user is not None:
return User.from_db(existing_user)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
async with transaction() as tx:
current_user = await prisma.models.User.prisma(tx).find_unique(
where={"id": auth_user_id}
)
if current_user is not None:
return User.from_db(current_user)
current_invited_user = await prisma.models.InvitedUser.prisma(tx).find_unique(
where={"email": normalized_email}
)
if current_invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
if current_invited_user.authUserId not in (None, auth_user_id):
raise NotAuthorizedError("Your invitation has already been claimed")
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await prisma.models.InvitedUser.prisma(tx).update(
where={"id": current_invited_user.id},
data={
"status": prisma.enums.InvitedUserStatus.CLAIMED,
"authUserId": auth_user_id,
},
)
await _ensure_default_profile(
auth_user_id,
normalized_email,
preferred_name,
tx,
)
await _ensure_default_onboarding(auth_user_id, tx)
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
from backend.data.user import get_user_by_email, get_user_by_id
get_user_by_id.cache_delete(auth_user_id)
get_user_by_email.cache_delete(normalized_email)
activated_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if activated_user is None:
raise RuntimeError(
f"Activated user {auth_user_id} was not found after creation"
)
return User.from_db(activated_user)

View File

@@ -0,0 +1,297 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import prisma.enums
import pytest
import pytest_mock
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
from .invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
get_or_activate_user,
retry_invited_user_tally,
)
def _invited_user_db_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="invite-1",
email="invited@example.com",
status=status,
authUserId=None,
name="Invited User",
tallyUnderstanding=tally_understanding,
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
tallyComputedAt=None,
tallyError=None,
createdAt=now,
updatedAt=now,
)
def _user_db_record():
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="auth-user-1",
email="invited@example.com",
emailVerified=True,
name="Invited User",
createdAt=now,
updatedAt=now,
metadata={},
integrations="",
stripeCustomerId=None,
topUpConfig=None,
maxEmailsPerDay=3,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="not-set",
)
@pytest.mark.asyncio
async def test_create_invited_user_rejects_existing_active_user(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock()
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(PreconditionFailed):
await create_invited_user("Invited@example.com")
@pytest.mark.asyncio
async def test_create_invited_user_schedules_tally_seed(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await create_invited_user("Invited@example.com", "Invited User")
assert invited_user.email == "invited@example.com"
invited_user_repo.create.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_retry_invited_user_tally_resets_state_and_schedules(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await retry_invited_user_tally("invite-1")
assert invited_user.id == "invite-1"
invited_user_repo.update.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_get_or_activate_user_requires_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(NotAuthorizedError):
await get_or_activate_user(
{"sub": "auth-user-1", "email": "invited@example.com"}
)
@pytest.mark.asyncio
async def test_get_or_activate_user_creates_user_from_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
tx = object()
invited_user = _invited_user_db_record(
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
)
created_user = _user_db_record()
outside_user_repo = Mock()
outside_user_repo.find_unique = AsyncMock(side_effect=[None, created_user])
inside_user_repo = Mock()
inside_user_repo.find_unique = AsyncMock(return_value=None)
inside_user_repo.create = AsyncMock(return_value=created_user)
outside_invited_repo = Mock()
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo = Mock()
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo.update = AsyncMock(return_value=invited_user)
def user_prisma(client=None):
return inside_user_repo if client is tx else outside_user_repo
def invited_user_prisma(client=None):
return inside_invited_repo if client is tx else outside_invited_repo
@asynccontextmanager
async def fake_transaction():
yield tx
ensure_profile = mocker.patch(
"backend.data.invited_user._ensure_default_profile",
AsyncMock(),
)
ensure_onboarding = mocker.patch(
"backend.data.invited_user._ensure_default_onboarding",
AsyncMock(),
)
apply_tally = mocker.patch(
"backend.data.invited_user._apply_tally_understanding",
AsyncMock(),
)
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
side_effect=invited_user_prisma,
)
mocker.patch("backend.data.user.get_user_by_id.cache_delete", Mock())
mocker.patch("backend.data.user.get_user_by_email.cache_delete", Mock())
user = await get_or_activate_user(
{
"sub": "auth-user-1",
"email": "Invited@example.com",
"user_metadata": {"name": "Invited User"},
}
)
assert user.id == "auth-user-1"
inside_user_repo.create.assert_awaited_once()
inside_invited_repo.update.assert_awaited_once()
ensure_profile.assert_awaited_once()
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
@pytest.mark.asyncio
async def test_bulk_create_invited_users_from_text_file(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_db_record(),
_invited_user_db_record(),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.txt",
b"Invited@example.com\nsecond@example.com\n",
)
assert result.created_count == 2
assert result.skipped_count == 0
assert result.error_count == 0
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
assert create_invited.await_count == 2
@pytest.mark.asyncio
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_db_record(),
PreconditionFailed("An invited user with this email already exists"),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.csv",
(
"email,name\n"
"valid@example.com,Valid User\n"
"not-an-email,Bad Row\n"
"valid@example.com,Duplicate In File\n"
"existing@example.com,Existing User\n"
).encode("utf-8"),
)
assert result.created_count == 1
assert result.skipped_count == 2
assert result.error_count == 1
assert [row.status for row in result.results] == [
"CREATED",
"ERROR",
"SKIPPED",
"SKIPPED",
]
assert create_invited.await_count == 2

View File

@@ -13,7 +13,14 @@ from prisma.types import (
)
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from pydantic import (
BaseModel,
ConfigDict,
EmailStr,
Field,
field_validator,
model_validator,
)
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
@@ -175,10 +182,26 @@ class RefundRequestData(BaseNotificationData):
balance: int
class AgentApprovalData(BaseNotificationData):
class _LegacyAgentFieldsMixin:
"""Temporary patch to handle existing queued payloads"""
# FIXME: remove in next release
@model_validator(mode="before")
@classmethod
def _map_legacy_agent_fields(cls, values: Any):
if isinstance(values, dict):
if "graph_id" not in values and "agent_id" in values:
values["graph_id"] = values.pop("agent_id")
if "graph_version" not in values and "agent_version" in values:
values["graph_version"] = values.pop("agent_version")
return values
class AgentApprovalData(_LegacyAgentFieldsMixin, BaseNotificationData):
agent_name: str
agent_id: str
agent_version: int
graph_id: str
graph_version: int
reviewer_name: str
reviewer_email: str
comments: str
@@ -193,10 +216,10 @@ class AgentApprovalData(BaseNotificationData):
return value
class AgentRejectionData(BaseNotificationData):
class AgentRejectionData(_LegacyAgentFieldsMixin, BaseNotificationData):
agent_name: str
agent_id: str
agent_version: int
graph_id: str
graph_version: int
reviewer_name: str
reviewer_email: str
comments: str

View File

@@ -15,8 +15,8 @@ class TestAgentApprovalData:
"""Test creating valid AgentApprovalData."""
data = AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="Great agent, approved!",
@@ -25,8 +25,8 @@ class TestAgentApprovalData:
)
assert data.agent_name == "Test Agent"
assert data.agent_id == "test-agent-123"
assert data.agent_version == 1
assert data.graph_id == "test-agent-123"
assert data.graph_version == 1
assert data.reviewer_name == "John Doe"
assert data.reviewer_email == "john@example.com"
assert data.comments == "Great agent, approved!"
@@ -40,8 +40,8 @@ class TestAgentApprovalData:
):
AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="Great agent, approved!",
@@ -53,8 +53,8 @@ class TestAgentApprovalData:
"""Test AgentApprovalData with empty comments."""
data = AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="", # Empty comments
@@ -72,8 +72,8 @@ class TestAgentRejectionData:
"""Test creating valid AgentRejectionData."""
data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the security issues before resubmitting.",
@@ -82,8 +82,8 @@ class TestAgentRejectionData:
)
assert data.agent_name == "Test Agent"
assert data.agent_id == "test-agent-123"
assert data.agent_version == 1
assert data.graph_id == "test-agent-123"
assert data.graph_version == 1
assert data.reviewer_name == "Jane Doe"
assert data.reviewer_email == "jane@example.com"
assert data.comments == "Please fix the security issues before resubmitting."
@@ -97,8 +97,8 @@ class TestAgentRejectionData:
):
AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the security issues.",
@@ -111,8 +111,8 @@ class TestAgentRejectionData:
long_comment = "A" * 1000 # Very long comment
data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments=long_comment,
@@ -126,8 +126,8 @@ class TestAgentRejectionData:
"""Test that models can be serialized and deserialized."""
original_data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the issues.",
@@ -142,8 +142,8 @@ class TestAgentRejectionData:
restored_data = AgentRejectionData.model_validate(data_dict)
assert restored_data.agent_name == original_data.agent_name
assert restored_data.agent_id == original_data.agent_id
assert restored_data.agent_version == original_data.agent_version
assert restored_data.graph_id == original_data.graph_id
assert restored_data.graph_version == original_data.graph_version
assert restored_data.reviewer_name == original_data.reviewer_name
assert restored_data.reviewer_email == original_data.reviewer_email
assert restored_data.comments == original_data.comments

View File

@@ -244,7 +244,10 @@ def _clean_and_split(text: str) -> list[str]:
def _calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
agent: prisma.models.StoreAgent,
categories: list[str],
custom: list[str],
integrations: list[str],
) -> int:
"""
Calculates the total points for an agent based on the specified criteria.
@@ -397,7 +400,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where={
"is_available": True,
"useForOnboarding": True,
"use_for_onboarding": True,
},
order=[
{"featured": "desc"},
@@ -407,7 +410,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
take=100,
)
# If not enough agents found, relax the useForOnboarding filter
# If not enough agents found, relax the use_for_onboarding filter
if len(storeAgents) < 2:
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
@@ -420,7 +423,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
)
# Calculate points for the first X agents and choose the top 2
agent_points = []
agent_points: list[tuple[prisma.models.StoreAgent, int]] = []
for agent in storeAgents[:POINTS_AGENT_COUNT]:
points = _calculate_points(
agent, categories, custom, user_onboarding.integrations
@@ -430,28 +433,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
agent_points.sort(key=lambda x: x[1], reverse=True)
recommended_agents = [agent for agent, _ in agent_points[:2]]
return [
StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
agentGraphVersions=agent.agentGraphVersions,
agentGraphId=agent.agentGraphId,
last_updated=agent.updated_at,
)
for agent in recommended_agents
]
return [StoreAgentDetails.from_db(agent) for agent in recommended_agents]
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes

View File

@@ -380,6 +380,35 @@ async def extract_business_understanding(
return BusinessUnderstandingInput(**cleaned)
async def get_business_understanding_input_from_tally(
email: str,
*,
require_api_key: bool = False,
) -> Optional[BusinessUnderstandingInput]:
settings = Settings()
if not settings.secrets.tally_api_key:
if require_api_key:
raise RuntimeError("Tally API key is not configured")
logger.debug("Tally: no API key configured, skipping")
return None
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return None
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return None
return await extract_business_understanding(formatted)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
@@ -394,30 +423,10 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
)
return
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
understanding_input = await get_business_understanding_input_from_tally(email)
if understanding_input is None:
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")

View File

@@ -166,6 +166,56 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
return merged
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> dict[str, Any]:
merged_data = dict(existing_data)
merged_business: dict[str, Any] = {}
if isinstance(merged_data.get("business"), dict):
merged_business = dict(merged_data["business"])
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
if input_data.user_name is not None:
merged_data["name"] = input_data.user_name
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
merged_business[field] = value
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(merged_business.get(field))
merged_list = _merge_lists(existing_list, value)
merged_business[field] = merged_list
merged_business["version"] = 1
merged_data["business"] = merged_business
return merged_data
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
@@ -245,63 +295,18 @@ async def upsert_business_understanding(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
merged_data = merge_business_understanding_data(existing_data, input_data)
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
},
)

View File

@@ -2,6 +2,7 @@ import logging
from unittest.mock import AsyncMock, patch
import fastapi.responses
import prisma
import pytest
import backend.api.features.library.model
@@ -497,9 +498,24 @@ async def test_store_listing_graph(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
# Ensure the test user has a Profile (required for store submissions)
existing_profile = await prisma.models.Profile.prisma().find_first(
where={"userId": test_user.id}
)
if not existing_profile:
await prisma.models.Profile.prisma().create(
data=prisma.types.ProfileCreateInput(
userId=test_user.id,
name=test_user.name or "Test User",
username=f"test-user-{test_user.id[:8]}",
description="Test user profile",
links=[],
)
)
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
agent_id=test_graph.id,
agent_version=test_graph.version,
graph_id=test_graph.id,
graph_version=test_graph.version,
slug=test_graph.id,
name="Test name",
sub_heading="Test sub heading",
@@ -517,8 +533,8 @@ async def test_store_listing_graph(server: SpinTestServer):
assert False, "Failed to create store listing"
slv_id = (
store_listing.store_listing_version_id
if store_listing.store_listing_version_id is not None
store_listing.listing_version_id
if store_listing.listing_version_id is not None
else None
)

View File

@@ -15,6 +15,7 @@ from backend.data import graph as graph_db
from backend.data import human_review as human_review_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db
from backend.data import workspace as workspace_db
# Import dynamic field utilities from centralized location
from backend.data.block import BlockInput, BlockOutputEntry
@@ -32,7 +33,6 @@ from backend.data.execution import (
from backend.data.graph import GraphModel, Node
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput, GraphInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.workspace import get_or_create_workspace
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
@@ -481,6 +481,22 @@ async def _construct_starting_node_execution_input(
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
input_data.update(node_input_mask)
# Webhook-triggered agents cannot be executed directly without payload data.
# Legitimate webhook triggers provide payload via nodes_input_masks above.
if (
block.block_type
in (
BlockType.WEBHOOK,
BlockType.WEBHOOK_MANUAL,
)
and "payload" not in input_data
):
raise ValueError(
"This agent is triggered by an external event (webhook) "
"and cannot be executed directly. "
"Please use the appropriate trigger to run this agent."
)
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
@@ -831,8 +847,9 @@ async def add_graph_execution(
udb = user_db
gdb = graph_db
odb = onboarding_db
wdb = workspace_db
else:
edb = udb = gdb = odb = get_database_manager_async_client()
edb = udb = gdb = odb = wdb = get_database_manager_async_client()
# Get or create the graph execution
if graph_exec_id:
@@ -892,7 +909,7 @@ async def add_graph_execution(
if execution_context is None:
user = await udb.get_user_by_id(user_id)
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
workspace = await get_or_create_workspace(user_id)
workspace = await wdb.get_or_create_workspace(user_id)
execution_context = ExecutionContext(
# Execution identity

View File

@@ -368,12 +368,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_wdb = mocker.patch("backend.executor.utils.workspace_db")
mock_workspace = mocker.MagicMock()
mock_workspace.id = "test-workspace-id"
mocker.patch(
"backend.executor.utils.get_or_create_workspace",
new=mocker.AsyncMock(return_value=mock_workspace),
)
mock_wdb.get_or_create_workspace = mocker.AsyncMock(return_value=mock_workspace)
# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
@@ -649,12 +647,10 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
mock_wdb = mocker.patch("backend.executor.utils.workspace_db")
mock_workspace = mocker.MagicMock()
mock_workspace.id = "test-workspace-id"
mocker.patch(
"backend.executor.utils.get_or_create_workspace",
new=mocker.AsyncMock(return_value=mock_workspace),
)
mock_wdb.get_or_create_workspace = mocker.AsyncMock(return_value=mock_workspace)
# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (

View File

@@ -76,7 +76,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
credentials_id=credentials.id,
webhook_type=webhook_type,
resource=resource,
events=None, # Ignore events for this lookup
):
# Re-register with Telegram using the same URL but new allowed_updates
ingress_url = webhook_ingress_url(self.PROVIDER_NAME, existing.id)
@@ -143,10 +142,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
elif "video" in message:
event_type = "message.video"
else:
logger.warning(
"Unknown Telegram webhook payload type; "
f"message.keys() = {message.keys()}"
)
event_type = "message.other"
elif "edited_message" in payload:
event_type = "message.edited_message"

View File

@@ -2,8 +2,8 @@
{#
Template variables:
data.agent_name: the name of the approved agent
data.agent_id: the ID of the agent
data.agent_version: the version of the agent
data.graph_id: the ID of the agent
data.graph_version: the version of the agent
data.reviewer_name: the name of the reviewer who approved it
data.reviewer_email: the email of the reviewer
data.comments: comments from the reviewer
@@ -70,4 +70,4 @@
Thank you for contributing to the AutoGPT ecosystem! 🚀
</p>
{% endblock %}
{% endblock %}

View File

@@ -2,8 +2,8 @@
{#
Template variables:
data.agent_name: the name of the rejected agent
data.agent_id: the ID of the agent
data.agent_version: the version of the agent
data.graph_id: the ID of the agent
data.graph_version: the version of the agent
data.reviewer_name: the name of the reviewer who rejected it
data.reviewer_email: the email of the reviewer
data.comments: comments from the reviewer explaining the rejection
@@ -74,4 +74,4 @@
We're excited to see your improved agent submission! 🚀
</p>
{% endblock %}
{% endblock %}

View File

@@ -64,6 +64,10 @@ class GraphNotInLibraryError(GraphNotAccessibleError):
"""Raised when attempting to execute a graph that is not / no longer in the user's library."""
class PreconditionFailed(Exception):
"""The user must do something else first before trying the current operation"""
class InsufficientBalanceError(ValueError):
user_id: str
message: str

View File

@@ -72,19 +72,58 @@ def dumps(
T = TypeVar("T")
@overload
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
# Sentinel value to detect when fallback is not provided
_NO_FALLBACK = object()
@overload
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(
data: str | bytes, *args, target_type: Type[T], fallback: T | None = None, **kwargs
) -> T:
pass
@overload
def loads(data: str | bytes, *args, fallback: Any = None, **kwargs) -> Any:
pass
def loads(
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
data: str | bytes,
*args,
target_type: Type[T] | None = None,
fallback: Any = _NO_FALLBACK,
**kwargs,
) -> Any:
parsed = orjson.loads(data)
"""Parse JSON with optional fallback on decode errors.
Args:
data: JSON string or bytes to parse
target_type: Optional type to validate/cast result to
fallback: Value to return on JSONDecodeError. If not provided, raises.
**kwargs: Additional arguments (unused, for compatibility)
Returns:
Parsed JSON data, or fallback value if parsing fails
Raises:
orjson.JSONDecodeError: Only if fallback is not provided
Examples:
>>> loads('{"valid": "json"}')
{'valid': 'json'}
>>> loads('invalid json', fallback=None)
None
>>> loads('invalid json', fallback={})
{}
>>> loads('invalid json') # raises orjson.JSONDecodeError
"""
try:
parsed = orjson.loads(data)
except orjson.JSONDecodeError:
if fallback is not _NO_FALLBACK:
return fallback
raise
if target_type:
return type_match(parsed, target_type)

View File

@@ -0,0 +1,32 @@
BEGIN;
-- Drop illogical column StoreListing.agentGraphVersion;
ALTER TABLE "StoreListing" DROP CONSTRAINT "StoreListing_agentGraphId_agentGraphVersion_fkey";
DROP INDEX "StoreListing_agentGraphId_agentGraphVersion_idx";
ALTER TABLE "StoreListing" DROP COLUMN "agentGraphVersion";
-- Add uniqueness constraint to Profile.userId and remove invalid data
--
-- Delete any profiles with null userId (which is invalid and doesn't occur in theory)
DELETE FROM "Profile" WHERE "userId" IS NULL;
--
-- Delete duplicate profiles per userId, keeping the most recently updated one
DELETE FROM "Profile"
WHERE "id" IN (
SELECT "id" FROM (
SELECT "id", ROW_NUMBER() OVER (
PARTITION BY "userId" ORDER BY "updatedAt" DESC, "id" DESC
) AS rn
FROM "Profile"
) ranked
WHERE rn > 1
);
--
-- Add userId uniqueness constraint
ALTER TABLE "Profile" ALTER COLUMN "userId" SET NOT NULL;
CREATE UNIQUE INDEX "Profile_userId_key" ON "Profile"("userId");
-- Add formal relation StoreListing.owningUserId -> Profile.userId
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_owner_Profile_fkey" FOREIGN KEY ("owningUserId") REFERENCES "Profile"("userId") ON DELETE CASCADE ON UPDATE CASCADE;
COMMIT;

View File

@@ -0,0 +1,219 @@
-- Update the StoreSubmission and StoreAgent views with additional fields, clearer field names, and faster joins.
-- Steps:
-- 1. Update `mv_agent_run_counts` to exclude runs by the agent's creator
-- a. Drop dependent views `StoreAgent` and `Creator`
-- b. Update `mv_agent_run_counts` and its index
-- c. Recreate `StoreAgent` view (with updates)
-- d. Restore `Creator` view
-- 2. Update `StoreSubmission` view
-- 3. Update `StoreListingReview` indices to make `StoreSubmission` query more efficient
BEGIN;
-- Drop views that are dependent on mv_agent_run_counts
DROP VIEW IF EXISTS "StoreAgent";
DROP VIEW IF EXISTS "Creator";
-- Update materialized view for agent run counts to exclude runs by the agent's creator
DROP INDEX IF EXISTS "idx_mv_agent_run_counts";
DROP MATERIALIZED VIEW IF EXISTS "mv_agent_run_counts";
CREATE MATERIALIZED VIEW "mv_agent_run_counts" AS
SELECT
run."agentGraphId" AS graph_id,
COUNT(*) AS run_count
FROM "AgentGraphExecution" run
JOIN "AgentGraph" graph ON graph.id = run."agentGraphId"
-- Exclude runs by the agent's creator to avoid inflating run counts
WHERE graph."userId" != run."userId"
GROUP BY run."agentGraphId";
-- Recreate index
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_agent_run_counts" ON "mv_agent_run_counts"("graph_id");
-- Re-populate the materialized view
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
-- Recreate the StoreAgent view with the following changes
-- (compared to 20260115210000_remove_storelistingversion_search):
-- - Narrow to *explicitly active* version (sl.activeVersionId) instead of MAX(version)
-- - Add `recommended_schedule_cron` column
-- - Rename `"storeListingVersionId"` -> `listing_version_id`
-- - Rename `"agentGraphVersions"` -> `graph_versions`
-- - Rename `"agentGraphId"` -> `graph_id`
-- - Rename `"useForOnboarding"` -> `use_for_onboarding`
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH store_agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_graph_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS listing_version_id,
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
slv."agentOutputDemoUrl" AS agent_output_demo,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
cp.username AS creator_username,
cp."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(arc.run_count, 0::bigint) AS runs,
COALESCE(reviews.avg_rating, 0.0)::double precision AS rating,
COALESCE(sav.versions, ARRAY[slv.version::text]) AS versions,
slv."agentGraphId" AS graph_id,
COALESCE(
agv.graph_versions,
ARRAY[slv."agentGraphVersion"::text]
) AS graph_versions,
slv."isAvailable" AS is_available,
COALESCE(sl."useForOnboarding", false) AS use_for_onboarding,
slv."recommendedScheduleCron" AS recommended_schedule_cron
FROM "StoreListing" AS sl
JOIN "StoreListingVersion" AS slv
ON slv."storeListingId" = sl.id
AND slv.id = sl."activeVersionId"
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" AS ag
ON slv."agentGraphId" = ag.id
AND slv."agentGraphVersion" = ag.version
LEFT JOIN "Profile" AS cp
ON sl."owningUserId" = cp."userId"
LEFT JOIN "mv_review_stats" AS reviews
ON sl.id = reviews."storeListingId"
LEFT JOIN "mv_agent_run_counts" AS arc
ON ag.id = arc.graph_id
LEFT JOIN store_agent_versions AS sav
ON sl.id = sav."storeListingId"
LEFT JOIN agent_graph_versions AS agv
ON sl.id = agv."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
-- Restore Creator view as last updated in 20250604130249_optimise_store_agent_and_creator_views,
-- with minor changes:
-- - Ensure top_categories always TEXT[]
-- - Filter out empty ('') categories
CREATE OR REPLACE VIEW "Creator" AS
WITH creator_listings AS (
SELECT
sl."owningUserId",
sl.id AS listing_id,
slv."agentGraphId",
slv.categories,
sr.score,
ar.run_count
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
AND slv."submissionStatus" = 'APPROVED'
LEFT JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv.id
LEFT JOIN "mv_agent_run_counts" ar
ON ar.graph_id = slv."agentGraphId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
),
creator_stats AS (
SELECT
cl."owningUserId",
COUNT(DISTINCT cl.listing_id) AS num_agents,
AVG(COALESCE(cl.score, 0)::numeric) AS agent_rating,
SUM(COALESCE(cl.run_count, 0)) AS agent_runs,
array_agg(DISTINCT cat ORDER BY cat)
FILTER (WHERE cat IS NOT NULL AND cat != '') AS all_categories
FROM creator_listings cl
LEFT JOIN LATERAL unnest(COALESCE(cl.categories, ARRAY[]::text[])) AS cat ON true
GROUP BY cl."owningUserId"
)
SELECT
p.username,
p.name,
p."avatarUrl" AS avatar_url,
p.description,
COALESCE(cs.all_categories, ARRAY[]::text[]) AS top_categories,
p.links,
p."isFeatured" AS is_featured,
COALESCE(cs.num_agents, 0::bigint) AS num_agents,
COALESCE(cs.agent_rating, 0.0) AS agent_rating,
COALESCE(cs.agent_runs, 0::numeric) AS agent_runs
FROM "Profile" p
LEFT JOIN creator_stats cs ON cs."owningUserId" = p."userId";
-- Recreate the StoreSubmission view with updated fields & query strategy:
-- - Uses mv_agent_run_counts instead of full AgentGraphExecution table scan + aggregation
-- - Renamed agent_id, agent_version -> graph_id, graph_version
-- - Renamed store_listing_version_id -> listing_version_id
-- - Renamed date_submitted -> submitted_at
-- - Renamed runs, rating -> run_count, review_avg_rating
-- - Added fields: instructions, agent_output_demo_url, review_count, is_deleted
DROP VIEW IF EXISTS "StoreSubmission";
CREATE OR REPLACE VIEW "StoreSubmission" AS
WITH review_stats AS (
SELECT
"storeListingVersionId" AS version_id, -- more specific than mv_review_stats
avg(score) AS avg_rating,
count(*) AS review_count
FROM "StoreListingReview"
GROUP BY "storeListingVersionId"
)
SELECT
sl.id AS listing_id,
sl."owningUserId" AS user_id,
sl.slug AS slug,
slv.id AS listing_version_id,
slv.version AS listing_version,
slv."agentGraphId" AS graph_id,
slv."agentGraphVersion" AS graph_version,
slv.name AS name,
slv."subHeading" AS sub_heading,
slv.description AS description,
slv.instructions AS instructions,
slv.categories AS categories,
slv."imageUrls" AS image_urls,
slv."videoUrl" AS video_url,
slv."agentOutputDemoUrl" AS agent_output_demo_url,
slv."submittedAt" AS submitted_at,
slv."changesSummary" AS changes_summary,
slv."submissionStatus" AS status,
slv."reviewedAt" AS reviewed_at,
slv."reviewerId" AS reviewer_id,
slv."reviewComments" AS review_comments,
slv."internalComments" AS internal_comments,
slv."isDeleted" AS is_deleted,
COALESCE(run_stats.run_count, 0::bigint) AS run_count,
COALESCE(review_stats.review_count, 0::bigint) AS review_count,
COALESCE(review_stats.avg_rating, 0.0)::double precision AS review_avg_rating
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN review_stats ON review_stats.version_id = slv.id
LEFT JOIN mv_agent_run_counts run_stats ON run_stats.graph_id = slv."agentGraphId"
WHERE sl."isDeleted" = false;
-- Drop unused index on StoreListingReview.reviewByUserId
DROP INDEX IF EXISTS "StoreListingReview_reviewByUserId_idx";
-- Add index on storeListingVersionId to make StoreSubmission query faster
CREATE INDEX "StoreListingReview_storeListingVersionId_idx" ON "StoreListingReview"("storeListingVersionId");
COMMIT;

View File

@@ -0,0 +1,114 @@
-- CreateEnum
CREATE TYPE "InvitedUserStatus" AS ENUM ('INVITED', 'CLAIMED', 'REVOKED');
-- CreateEnum
CREATE TYPE "TallyComputationStatus" AS ENUM ('PENDING', 'RUNNING', 'READY', 'FAILED');
-- CreateTable
CREATE TABLE "InvitedUser" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"email" TEXT NOT NULL,
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
"authUserId" TEXT,
"name" TEXT,
"tallyUnderstanding" JSONB,
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
"tallyComputedAt" TIMESTAMP(3),
"tallyError" TEXT,
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "InvitedUser_email_key" ON "InvitedUser"("email");
-- CreateIndex
CREATE UNIQUE INDEX "InvitedUser_authUserId_key" ON "InvitedUser"("authUserId");
-- CreateIndex
CREATE INDEX "InvitedUser_status_idx" ON "InvitedUser"("status");
-- CreateIndex
CREATE INDEX "InvitedUser_tallyStatus_idx" ON "InvitedUser"("tallyStatus");
-- AddForeignKey
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey"
FOREIGN KEY ("authUserId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
DO $$
DECLARE
allowed_users_schema TEXT;
BEGIN
SELECT table_schema
INTO allowed_users_schema
FROM information_schema.columns
WHERE table_name = 'allowed_users'
AND column_name = 'email'
ORDER BY CASE
WHEN table_schema = 'platform' THEN 0
WHEN table_schema = 'public' THEN 1
ELSE 2
END
LIMIT 1;
IF allowed_users_schema IS NOT NULL THEN
EXECUTE format(
'INSERT INTO platform."InvitedUser" ("id", "email", "status", "tallyStatus", "createdAt", "updatedAt")
SELECT gen_random_uuid()::text,
lower(email),
''INVITED''::platform."InvitedUserStatus",
''PENDING''::platform."TallyComputationStatus",
now(),
now()
FROM %I.allowed_users
WHERE email IS NOT NULL
ON CONFLICT ("email") DO NOTHING',
allowed_users_schema
);
END IF;
END $$;
CREATE OR REPLACE FUNCTION platform.ensure_invited_user_can_register()
RETURNS TRIGGER AS $$
BEGIN
IF NEW.email IS NULL THEN
RAISE EXCEPTION 'The email address "%" is not allowed to register. Please contact support for assistance.', NEW.email
USING ERRCODE = 'P0001';
END IF;
IF lower(split_part(NEW.email, '@', 2)) = 'agpt.co' THEN
RETURN NEW;
END IF;
IF EXISTS (
SELECT 1
FROM platform."InvitedUser" invited_user
WHERE lower(invited_user.email) = lower(NEW.email)
AND invited_user.status = 'INVITED'::platform."InvitedUserStatus"
) THEN
RETURN NEW;
END IF;
RAISE EXCEPTION 'The email address "%" is not allowed to register. Please contact support for assistance.', NEW.email
USING ERRCODE = 'P0001';
END;
$$ LANGUAGE plpgsql SECURITY DEFINER;
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'auth'
AND table_name = 'users'
) THEN
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
DROP TRIGGER IF EXISTS invited_user_signup_gate ON auth.users;
CREATE TRIGGER invited_user_signup_gate
BEFORE INSERT ON auth.users
FOR EACH ROW EXECUTE FUNCTION platform.ensure_invited_user_can_register();
END IF;
END $$;

View File

@@ -65,6 +65,7 @@ model User {
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -73,6 +74,39 @@ model User {
OAuthRefreshTokens OAuthRefreshToken[]
}
enum InvitedUserStatus {
INVITED
CLAIMED
REVOKED
}
enum TallyComputationStatus {
PENDING
RUNNING
READY
FAILED
}
model InvitedUser {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
email String @unique
status InvitedUserStatus @default(INVITED)
authUserId String? @unique
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
name String?
tallyUnderstanding Json?
tallyStatus TallyComputationStatus @default(PENDING)
tallyComputedAt DateTime?
tallyError String?
@@index([status])
@@index([tallyStatus])
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME
@@ -281,7 +315,6 @@ model AgentGraph {
Presets AgentPreset[]
LibraryAgents LibraryAgent[]
StoreListings StoreListing[]
StoreListingVersions StoreListingVersion[]
@@id(name: "graphVersionId", [id, version])
@@ -814,10 +847,8 @@ model Profile {
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// Only 1 of user or group can be set.
// The user this profile belongs to, if any.
userId String?
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
name String
username String @unique
@@ -830,6 +861,7 @@ model Profile {
isFeatured Boolean @default(false)
LibraryAgents LibraryAgent[]
StoreListings StoreListing[]
@@index([userId])
}
@@ -860,9 +892,9 @@ view Creator {
}
view StoreAgent {
listing_id String @id
storeListingVersionId String
updated_at DateTime
listing_id String @id
listing_version_id String
updated_at DateTime
slug String
agent_name String
@@ -879,10 +911,12 @@ view StoreAgent {
runs Int
rating Float
versions String[]
agentGraphVersions String[]
agentGraphId String
graph_id String
graph_versions String[]
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
use_for_onboarding Boolean @default(false)
recommended_schedule_cron String?
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
@@ -896,41 +930,52 @@ view StoreAgent {
}
view StoreSubmission {
listing_id String @id
user_id String
slug String
name String
sub_heading String
description String
image_urls String[]
date_submitted DateTime
status SubmissionStatus
runs Int
rating Float
agent_id String
agent_version Int
store_listing_version_id String
reviewer_id String?
review_comments String?
internal_comments String?
reviewed_at DateTime?
changes_summary String?
video_url String?
categories String[]
// From StoreListing:
listing_id String
user_id String
slug String
// Index or unique are not applied to views
// From StoreListingVersion:
listing_version_id String @id
listing_version Int
graph_id String
graph_version Int
name String
sub_heading String
description String
instructions String?
categories String[]
image_urls String[]
video_url String?
agent_output_demo_url String?
submitted_at DateTime?
changes_summary String?
status SubmissionStatus
reviewed_at DateTime?
reviewer_id String?
review_comments String?
internal_comments String?
is_deleted Boolean
// Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count Int
review_count Int
review_avg_rating Float
}
// Note: This is actually a MATERIALIZED VIEW in the database
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
view mv_agent_run_counts {
agentGraphId String @unique
run_count Int
graph_id String @unique
run_count Int // excluding runs by the graph's creator
// Pre-aggregated count of AgentGraphExecution records by agentGraphId
// Used by StoreAgent and Creator views for performance optimization
// Unique index created automatically on agentGraphId for fast lookups
// Refresh uses CONCURRENTLY to avoid blocking reads
// Pre-aggregated count of AgentGraphExecution records by agentGraphId.
// Used by StoreAgent, Creator, and StoreSubmission views for performance optimization.
// - Should have a unique index on graph_id for fast lookups
// - Refresh should use CONCURRENTLY to avoid blocking reads
}
// Note: This is actually a MATERIALIZED VIEW in the database
@@ -979,22 +1024,18 @@ model StoreListing {
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
// The agent link here is only so we can do lookup on agentId
agentGraphId String
agentGraphVersion Int
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
agentGraphId String @unique
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
CreatorProfile Profile @relation(fields: [owningUserId], references: [userId], map: "StoreListing_owner_Profile_fkey", onDelete: Cascade)
// Relations
Versions StoreListingVersion[] @relation("ListingVersions")
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
@@unique([agentGraphId])
@@unique([owningUserId, slug])
// Used in the view query
@@index([isDeleted, hasApprovedVersion])
@@index([agentGraphId, agentGraphVersion])
}
model StoreListingVersion {
@@ -1089,16 +1130,16 @@ model UnifiedContentEmbedding {
// Search data
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
searchableText String // Combined text for search and fallback
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
metadata Json @default("{}") // Content-specific metadata
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
@@index([contentType])
@@index([userId])
@@index([contentType, userId])
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
}
model StoreListingReview {
@@ -1115,8 +1156,9 @@ model StoreListingReview {
score Int
comments String?
// Enforce one review per user per listing version
@@unique([storeListingVersionId, reviewByUserId])
@@index([reviewByUserId])
@@index([storeListingVersionId])
}
enum SubmissionStatus {

View File

@@ -23,14 +23,14 @@
"1.0.0",
"1.1.0"
],
"agentGraphVersions": [
"graph_id": "test-graph-id",
"graph_versions": [
"1",
"2"
],
"agentGraphId": "test-graph-id",
"last_updated": "2023-01-01T00:00:00",
"recommended_schedule_cron": null,
"active_version_id": null,
"has_approved_version": false,
"active_version_id": "test-version-id",
"has_approved_version": true,
"changelog": null
}

View File

@@ -1,14 +1,16 @@
{
"name": "Test User",
"username": "creator1",
"name": "Test User",
"description": "Test creator description",
"avatar_url": "avatar.jpg",
"links": [
"link1.com",
"link2.com"
],
"avatar_url": "avatar.jpg",
"agent_rating": 4.8,
"is_featured": true,
"num_agents": 5,
"agent_runs": 1000,
"agent_rating": 4.8,
"top_categories": [
"category1",
"category2"

View File

@@ -1,54 +1,94 @@
{
"creators": [
{
"name": "Creator 0",
"username": "creator0",
"name": "Creator 0",
"description": "Creator 0 description",
"avatar_url": "avatar0.jpg",
"links": [
"user0.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 1",
"username": "creator1",
"name": "Creator 1",
"description": "Creator 1 description",
"avatar_url": "avatar1.jpg",
"links": [
"user1.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 2",
"username": "creator2",
"name": "Creator 2",
"description": "Creator 2 description",
"avatar_url": "avatar2.jpg",
"links": [
"user2.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 3",
"username": "creator3",
"name": "Creator 3",
"description": "Creator 3 description",
"avatar_url": "avatar3.jpg",
"links": [
"user3.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 4",
"username": "creator4",
"name": "Creator 4",
"description": "Creator 4 description",
"avatar_url": "avatar4.jpg",
"links": [
"user4.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
}
],
"pagination": {

View File

@@ -2,32 +2,33 @@
"submissions": [
{
"listing_id": "test-listing-id",
"agent_id": "test-agent-id",
"agent_version": 1,
"user_id": "test-user-id",
"slug": "test-agent",
"listing_version_id": "test-version-id",
"listing_version": 1,
"graph_id": "test-agent-id",
"graph_version": 1,
"name": "Test Agent",
"sub_heading": "Test agent subheading",
"slug": "test-agent",
"description": "Test agent description",
"instructions": null,
"instructions": "Click the button!",
"categories": [
"test-category"
],
"image_urls": [
"test.jpg"
],
"date_submitted": "2023-01-01T00:00:00",
"video_url": "test.mp4",
"agent_output_demo_url": "demo_video.mp4",
"submitted_at": "2023-01-01T00:00:00",
"changes_summary": "Initial Submission",
"status": "APPROVED",
"runs": 50,
"rating": 4.2,
"store_listing_version_id": null,
"version": null,
"reviewed_at": null,
"reviewer_id": null,
"review_comments": null,
"internal_comments": null,
"reviewed_at": null,
"changes_summary": null,
"video_url": "test.mp4",
"agent_output_demo_url": null,
"categories": [
"test-category"
]
"run_count": 50,
"review_count": 5,
"review_avg_rating": 4.2
}
],
"pagination": {

View File

@@ -128,7 +128,7 @@ class TestDataCreator:
email = "test123@gmail.com"
else:
email = faker.unique.email()
password = "testpassword123" # Standard test password
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
user_id = f"test-user-{i}-{faker.uuid4()}"
# Create user in Supabase Auth (if needed)
@@ -571,8 +571,8 @@ class TestDataCreator:
if test_user and self.agent_graphs:
test_submission_data = {
"user_id": test_user["id"],
"agent_id": self.agent_graphs[0]["id"],
"agent_version": 1,
"graph_id": self.agent_graphs[0]["id"],
"graph_version": 1,
"slug": "test-agent-submission",
"name": "Test Agent Submission",
"sub_heading": "A test agent for frontend testing",
@@ -593,9 +593,9 @@ class TestDataCreator:
print("✅ Created special test store submission for test123@gmail.com")
# ALWAYS approve and feature the test submission
if test_submission.store_listing_version_id:
if test_submission.listing_version_id:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.store_listing_version_id,
store_listing_version_id=test_submission.listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
@@ -605,7 +605,7 @@ class TestDataCreator:
print("✅ Approved test store submission")
await prisma.storelistingversion.update(
where={"id": test_submission.store_listing_version_id},
where={"id": test_submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
@@ -640,8 +640,8 @@ class TestDataCreator:
submission = await create_store_submission(
user_id=user["id"],
agent_id=graph["id"],
agent_version=graph.get("version", 1),
graph_id=graph["id"],
graph_version=graph.get("version", 1),
slug=faker.slug(),
name=graph.get("name", faker.sentence(nb_words=3)),
sub_heading=faker.sentence(),
@@ -654,7 +654,7 @@ class TestDataCreator:
submissions.append(submission.model_dump())
print(f"✅ Created store submission: {submission.name}")
if submission.store_listing_version_id:
if submission.listing_version_id:
# DETERMINISTIC: First N submissions are always approved
# First GUARANTEED_FEATURED_AGENTS of those are always featured
should_approve = (
@@ -667,7 +667,7 @@ class TestDataCreator:
try:
reviewer_id = random.choice(self.users)["id"]
approved_submission = await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
store_listing_version_id=submission.listing_version_id,
is_approved=True,
external_comments="Auto-approved for E2E testing",
internal_comments="Automatically approved by E2E test data script",
@@ -683,9 +683,7 @@ class TestDataCreator:
if should_feature:
try:
await prisma.storelistingversion.update(
where={
"id": submission.store_listing_version_id
},
where={"id": submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
@@ -699,9 +697,7 @@ class TestDataCreator:
elif random.random() < 0.2:
try:
await prisma.storelistingversion.update(
where={
"id": submission.store_listing_version_id
},
where={"id": submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
@@ -721,7 +717,7 @@ class TestDataCreator:
try:
reviewer_id = random.choice(self.users)["id"]
await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
store_listing_version_id=submission.listing_version_id,
is_approved=False,
external_comments="Submission rejected - needs improvements",
internal_comments="Automatically rejected by E2E test data script",

View File

@@ -394,7 +394,6 @@ async def main():
listing = await db.storelisting.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,

View File

@@ -0,0 +1,393 @@
# Beta Invite Pre-Provisioning Design
## Problem
The current signup path is split across three places:
- Supabase creates the auth user and password.
- The backend lazily creates `platform.User` on first authenticated request in [`backend/backend/data/user.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/user.py).
- A Postgres trigger creates `platform.Profile` after `auth.users` insert in [`backend/migrations/20250205100104_add_profile_trigger/migration.sql`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/migrations/20250205100104_add_profile_trigger/migration.sql).
That works for open signup plus a DB-level allowlist, but it does not give the platform a durable object representing:
- a beta invite before the person signs up
- pre-computed onboarding data tied to an email before auth exists
- the distinction between "invited", "claimed", "active", and "revoked"
It also makes first-login setup racey because `User`, `Profile`, `UserOnboarding`, and `CoPilotUnderstanding` are not created from one source of truth.
## Goals
- Allow staff to invite a user before they have a Supabase account.
- Pre-create the platform-side user record and related data before first login.
- Keep Supabase responsible for password entry, email verification, sessions, and OAuth providers.
- Support both password signup and magic-link / OAuth signup for invited users.
- Preserve the existing ability to populate Tally-derived understanding and onboarding defaults.
- Make first login idempotent and safe if the user retries or signs in with a different method.
## Non-goals
- Replacing Supabase Auth.
- Building a full enterprise identity management system.
- Solving general-purpose team/org invites in this change.
## Proposed model
Introduce invite-backed pre-provisioning with email as the pre-auth identity key, then bind the invite to the Supabase `auth.users.id` when the user claims it.
### New tables
#### `BetaInvite`
Represents an invitation sent to a person before they create credentials.
Suggested fields:
- `id`
- `email` unique
- `status` enum: `PENDING`, `CLAIMED`, `EXPIRED`, `REVOKED`
- `inviteTokenHash` nullable
- `invitedByUserId` nullable
- `expiresAt` nullable
- `claimedAt` nullable
- `claimedAuthUserId` nullable unique
- `metadata` jsonb
- `createdAt`
- `updatedAt`
`metadata` should hold operational fields only, for example:
- source: manual import, admin UI, CSV
- cohort: beta wave name
- notes
- original tally email if normalization differs
#### `PreProvisionedUser`
Represents the platform-side user state that exists before Supabase auth is bound.
Suggested fields:
- `id`
- `inviteId` unique
- `email` unique
- `authUserId` nullable unique
- `status` enum: `PENDING_CLAIM`, `ACTIVE`, `MERGE_REQUIRED`, `DISABLED`
- `name` nullable
- `timezone` default `not-set`
- `emailVerified` default `false`
- `metadata` jsonb
- `createdAt`
- `updatedAt`
This is the durable record staff can enrich before login.
#### `PreProvisionedUserSeed`
Stores structured seed data that should become first-class user records on claim.
Suggested fields:
- `id`
- `preProvisionedUserId` unique
- `profile` jsonb nullable
- `onboarding` jsonb nullable
- `businessUnderstanding` jsonb nullable
- `promptContext` jsonb nullable
- `createdAt`
- `updatedAt`
This avoids polluting the main `User.metadata` blob and gives explicit ownership over what is seed data versus ongoing user-authored data.
## Why not pre-create `platform.User` directly?
`platform.User.id` is currently designed to equal the Supabase user id in [`backend/schema.prisma`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/schema.prisma). Pre-creating `User` with a fake UUID would ripple through every FK and complicate the eventual bind. Reusing email as the join key inside `User` is also not safe enough because the rest of the system assumes `User.id` is the canonical identity.
A separate pre-provisioning layer is lower risk:
- it avoids breaking every relation off `User`
- it keeps the existing auth token model intact
- it gives a clean migration path to merge into `User` once a Supabase identity exists
## Claim flow
### 1. Staff creates invite
Admin API or script:
1. Create `BetaInvite`.
2. Create `PreProvisionedUser`.
3. Create `PreProvisionedUserSeed`.
4. Optionally fetch and store Tally-derived understanding immediately by email.
5. Optionally send invite email containing a claim link.
### 2. User opens signup page
Two supported modes:
- direct signup with invited email
- invite-link signup with a short-lived token
Preferred UX:
- `/signup?invite=<token>`
- frontend resolves token to masked email and invite status
- email field is prefilled and locked by default
### 3. User creates Supabase account
Frontend still calls Supabase `signUp`, but include invite context in user metadata:
- `invite_id`
- `invite_email`
This is useful for debugging but should not be the source of truth.
### 4. First authenticated backend call activates invite
Replace the current pure `get_or_create_user` behavior with:
1. Read JWT `sub` and `email`.
2. Look for existing `platform.User` by `id`.
3. If found, return it.
4. Else look for `PreProvisionedUser` by normalized email and `status = PENDING_CLAIM`.
5. In one transaction:
- create `platform.User` with `id = auth sub`
- bind `PreProvisionedUser.authUserId = auth sub`
- mark `PreProvisionedUser.status = ACTIVE`
- mark `BetaInvite.status = CLAIMED`
- create or upsert `Profile`
- create or upsert `UserOnboarding`
- create or upsert `CoPilotUnderstanding`
- create `UserWorkspace` if required for the product experience
6. If no pre-provisioned record exists, either:
- reject access for closed beta, or
- fall back to normal creation if a feature flag allows open signup
This activation logic should live in the backend service, not a Postgres trigger, because it needs to coordinate across multiple platform tables and apply merge rules.
## Data merge rules
When activating a pre-provisioned invite:
- `User`
- source of truth for `id` is Supabase JWT `sub`
- source of truth for `email` is Supabase auth email
- `name` prefers pre-provisioned value, falls back to auth metadata name
- `Profile`
- if seed profile exists, use it
- else create the current generated default
- never overwrite an existing profile if activation is retried
- `UserOnboarding`
- seed values are initial defaults only
- use `upsert` with create-on-missing semantics
- `CoPilotUnderstanding`
- seed from stored Tally extraction if present
- skip Tally backfill on first login if this already exists
- prompt/tally-specific context
- do not jam this into `User.metadata` unless no better typed model exists
- use `PreProvisionedUserSeed.promptContext` now, migrate to typed tables later if product solidifies
## Invite enforcement
The current closed-beta gate appears to rely on a Supabase-side DB error surfaced as "not allowed" in [`frontend/src/app/api/auth/utils.ts`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/api/auth/utils.ts).
That should evolve to:
### Short term
Keep the current Supabase signup restriction, but have it check `BetaInvite`/`PreProvisionedUser` instead of a raw allowlist.
### Medium term
Stop using a generic DB exception as the main product signal and expose a backend endpoint:
- `POST /internal/beta-invites/validate`
- `GET /internal/beta-invites/:token`
Then the frontend can fail earlier with a specific state:
- invited and ready
- invite expired
- already claimed
- not invited
Supabase should still remain the hard gate for credential creation, but the UI should stop depending on opaque trigger failures for normal control flow.
## Trigger changes
The `auth.users` trigger that creates `platform.User` and `platform.Profile` should be removed once activation logic is live.
Reason:
- it cannot see or safely apply pre-provisioned seed data
- it only handles create, not merge/bind
- it duplicates responsibility already present in `get_or_create_user`
Interim state during rollout:
- keep trigger disabled for production once backend activation ships
- keep a backfill script for any auth users created during the transition
## API surface
### Admin APIs
- `POST /admin/beta-invites`
- create invite + pre-provisioned user + seed data
- `POST /admin/beta-invites/bulk`
- CSV import for beta cohorts
- `POST /admin/beta-invites/:id/resend`
- `POST /admin/beta-invites/:id/revoke`
- `GET /admin/beta-invites`
### Public/auth-adjacent APIs
- `GET /beta-invites/lookup?token=...`
- return safe invite info for signup page
- `POST /beta-invites/claim-preview`
- optional: check whether an email is invited before attempting Supabase signup
### Internal service function changes
- Replace `get_or_create_user` with `get_or_activate_user`
- keep a compatibility wrapper if needed to limit churn
## Suggested backend implementation
### New module
- `backend/backend/data/beta_invite.py`
Responsibilities:
- create invite
- resolve invite token
- normalize email
- activate pre-provisioned user
- revoke / expire invite
### Existing module changes
- [`backend/backend/data/user.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/user.py)
- move lazy creation logic into activation-aware flow
- [`backend/backend/api/features/v1.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/api/features/v1.py)
- `/auth/user` should call activation-aware function
- only run background Tally population if no seeded understanding exists
- [`backend/backend/data/tally.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/tally.py)
- add a reusable "seed from email" function for invite creation time
## Suggested frontend implementation
### Signup
- read invite token from URL
- call invite lookup endpoint
- prefill locked email when token is valid
- if invite is invalid, show specific error, not generic waitlist modal
Files likely affected:
- [`frontend/src/app/(platform)/signup/page.tsx`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/(platform)/signup/page.tsx)
- [`frontend/src/app/(platform)/signup/actions.ts`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts)
- [`frontend/src/components/auth/WaitlistErrorContent.tsx`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/components/auth/WaitlistErrorContent.tsx)
### Admin UI
Add a simple internal page for beta invites instead of manually editing allowlists.
Possible location:
- `frontend/src/app/(platform)/admin/users/invites/page.tsx`
## Rollout plan
### Phase 1: schema and service layer
- add `BetaInvite`, `PreProvisionedUser`, `PreProvisionedUserSeed`
- implement activation transaction
- keep existing trigger/allowlist in place
### Phase 2: admin creation path
- add admin API or CLI script
- support single invite and CSV bulk upload
- seed Tally/business understanding during invite creation
### Phase 3: signup UX
- invite token lookup
- better invite state messaging
- preserve existing closed beta modal for non-invited traffic
### Phase 4: remove legacy coupling
- disable `auth.users` profile trigger
- simplify `get_or_create_user`
- migrate allowlist logic to invite tables
## Edge cases
### Existing auth user, no platform user
This already happens today. Activation flow should treat it as:
- if auth email matches a pending pre-provisioned invite, bind and activate
- else create a plain `User` only if open-signup feature flag is enabled
### Existing platform user, invited again
Do not create another `PreProvisionedUser`. Create a second invite only if product explicitly wants re-invites. Otherwise reject as duplicate.
### Email changed after invite
Support admin-side reissue:
- revoke old invite
- create new invite with new email
- move seed data forward
Do not automatically bind across unrelated email addresses.
### OAuth signup
OAuth provider signups still work as long as the resulting Supabase email matches the invited email. If the provider returns a different email, activation should fail with a clear UI message.
### Tally data arrives after invite creation
Allow re-seeding before claim if `CoPilotUnderstanding` has not yet been created on activation.
## Migration notes
### Existing beta users
No immediate migration required for already-active users. This system is mainly for future invited users.
### Existing allowlist entries
Backfill them into `BetaInvite` plus `PreProvisionedUser`, then swap the Supabase gating logic to consult invite tables instead of the legacy allowlist table/trigger.
## Recommendation
Implement the pre-provisioning layer first and keep `platform.User` bound to Supabase `auth.users.id`. That is the lowest-risk design because it respects the existing identity model while giving the business exactly what it needs:
- invite someone before signup
- compute and store Tally/onboarding/prompt defaults before first login
- activate those defaults atomically when the user actually creates credentials
## First implementation slice
The smallest useful slice is:
1. Add `BetaInvite`, `PreProvisionedUser`, and `PreProvisionedUserSeed`.
2. Add a backend admin endpoint to create an invite from email plus optional seed payload.
3. Change `/auth/user` activation logic to bind and materialize seeded `Profile`, `UserOnboarding`, and `CoPilotUnderstanding`.
4. Keep the existing signup UI, but validate invite membership before or during signup.
That delivers the core behavior without needing the full admin UI on day one.

View File

@@ -10,6 +10,7 @@
"cssVariables": false,
"prefix": ""
},
"iconLibrary": "radix",
"aliases": {
"components": "@/components",
"utils": "@/lib/utils"

View File

@@ -1,5 +1,11 @@
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import {
Users,
DollarSign,
UserSearch,
FileText,
UserPlus,
} from "lucide-react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -16,6 +22,11 @@ const sidebarLinkGroups = [
href: "/admin/spending",
icon: <DollarSign className="h-6 w-6" />,
},
{
text: "Beta Invites",
href: "/admin/users",
icon: <UserPlus className="h-6 w-6" />,
},
{
text: "User Impersonation",
href: "/admin/impersonation",

View File

@@ -1,33 +1,39 @@
"use server";
import { revalidatePath } from "next/cache";
import BackendApi from "@/lib/autogpt-server-api";
import {
StoreListingsWithVersionsResponse,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
getV2GetAdminListingsHistory,
postV2ReviewStoreSubmission,
getV2AdminDownloadAgentFile,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
export async function approveAgent(formData: FormData) {
const data = {
store_listing_version_id: formData.get("id") as string,
const storeListingVersionId = formData.get("id") as string;
const comments = formData.get("comments") as string;
await postV2ReviewStoreSubmission(storeListingVersionId, {
store_listing_version_id: storeListingVersionId,
is_approved: true,
comments: formData.get("comments") as string,
};
const api = new BackendApi();
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
comments,
});
revalidatePath("/admin/marketplace");
}
export async function rejectAgent(formData: FormData) {
const data = {
store_listing_version_id: formData.get("id") as string,
const storeListingVersionId = formData.get("id") as string;
const comments = formData.get("comments") as string;
const internal_comments =
(formData.get("internal_comments") as string) || undefined;
await postV2ReviewStoreSubmission(storeListingVersionId, {
store_listing_version_id: storeListingVersionId,
is_approved: false,
comments: formData.get("comments") as string,
internal_comments: formData.get("internal_comments") as string,
};
const api = new BackendApi();
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
comments,
internal_comments,
});
revalidatePath("/admin/marketplace");
}
@@ -37,26 +43,18 @@ export async function getAdminListingsWithVersions(
search?: string,
page: number = 1,
pageSize: number = 20,
): Promise<StoreListingsWithVersionsResponse> {
const data: Record<string, any> = {
) {
const response = await getV2GetAdminListingsHistory({
status,
search,
page,
page_size: pageSize,
};
});
if (status) {
data.status = status;
}
if (search) {
data.search = search;
}
const api = new BackendApi();
const response = await api.getAdminListingsWithVersions(data);
return response;
return okData(response);
}
export async function downloadAsAdmin(storeListingVersion: string) {
const api = new BackendApi();
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
return file;
const response = await getV2AdminDownloadAgentFile(storeListingVersion);
return okData(response);
}

View File

@@ -6,10 +6,8 @@ import {
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
import {
StoreSubmission,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { PaginationControls } from "../../../../../components/__legacy__/ui/pagination-controls";
import { getAdminListingsWithVersions } from "@/app/(platform)/admin/marketplace/actions";
import { ExpandableRow } from "./ExpandleRow";
@@ -17,12 +15,14 @@ import { SearchAndFilterAdminMarketplace } from "./SearchFilterForm";
// Helper function to get the latest version by version number
const getLatestVersionByNumber = (
versions: StoreSubmission[],
): StoreSubmission | null => {
versions: StoreSubmissionAdminView[] | undefined,
): StoreSubmissionAdminView | null => {
if (!versions || versions.length === 0) return null;
return versions.reduce(
(latest, current) =>
(current.version ?? 0) > (latest.version ?? 1) ? current : latest,
(current.listing_version ?? 0) > (latest.listing_version ?? 1)
? current
: latest,
versions[0],
);
};
@@ -37,12 +37,14 @@ export async function AdminAgentsDataTable({
initialSearch?: string;
}) {
// Server-side data fetching
const { listings, pagination } = await getAdminListingsWithVersions(
const data = await getAdminListingsWithVersions(
initialStatus,
initialSearch,
initialPage,
10,
);
const listings = data?.listings ?? [];
const pagination = data?.pagination;
return (
<div className="space-y-4">
@@ -92,7 +94,7 @@ export async function AdminAgentsDataTable({
<PaginationControls
currentPage={initialPage}
totalPages={pagination.total_pages}
totalPages={pagination?.total_pages ?? 1}
/>
</div>
);

View File

@@ -13,7 +13,7 @@ import {
} from "@/components/__legacy__/ui/dialog";
import { Label } from "@/components/__legacy__/ui/label";
import { Textarea } from "@/components/__legacy__/ui/textarea";
import type { StoreSubmission } from "@/lib/autogpt-server-api/types";
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
import { useRouter } from "next/navigation";
import {
approveAgent,
@@ -23,7 +23,7 @@ import {
export function ApproveRejectButtons({
version,
}: {
version: StoreSubmission;
version: StoreSubmissionAdminView;
}) {
const router = useRouter();
const [isApproveDialogOpen, setIsApproveDialogOpen] = useState(false);
@@ -95,7 +95,7 @@ export function ApproveRejectButtons({
<input
type="hidden"
name="id"
value={version.store_listing_version_id || ""}
value={version.listing_version_id || ""}
/>
<div className="grid gap-4 py-4">
@@ -142,7 +142,7 @@ export function ApproveRejectButtons({
<input
type="hidden"
name="id"
value={version.store_listing_version_id || ""}
value={version.listing_version_id || ""}
/>
<div className="grid gap-4 py-4">

View File

@@ -12,11 +12,9 @@ import {
import { Badge } from "@/components/__legacy__/ui/badge";
import { ChevronDown, ChevronRight } from "lucide-react";
import { formatDistanceToNow } from "date-fns";
import {
type StoreListingWithVersions,
type StoreSubmission,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
import type { StoreListingWithVersionsAdminView } from "@/app/api/__generated__/models/storeListingWithVersionsAdminView";
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { ApproveRejectButtons } from "./ApproveRejectButton";
import { DownloadAgentAdminButton } from "./DownloadAgentButton";
@@ -38,8 +36,8 @@ export function ExpandableRow({
listing,
latestVersion,
}: {
listing: StoreListingWithVersions;
latestVersion: StoreSubmission | null;
listing: StoreListingWithVersionsAdminView;
latestVersion: StoreSubmissionAdminView | null;
}) {
const [expanded, setExpanded] = useState(false);
@@ -69,17 +67,17 @@ export function ExpandableRow({
{latestVersion?.status && getStatusBadge(latestVersion.status)}
</TableCell>
<TableCell onClick={() => setExpanded(!expanded)}>
{latestVersion?.date_submitted
? formatDistanceToNow(new Date(latestVersion.date_submitted), {
{latestVersion?.submitted_at
? formatDistanceToNow(new Date(latestVersion.submitted_at), {
addSuffix: true,
})
: "Unknown"}
</TableCell>
<TableCell className="text-right">
<div className="flex justify-end gap-2">
{latestVersion?.store_listing_version_id && (
{latestVersion?.listing_version_id && (
<DownloadAgentAdminButton
storeListingVersionId={latestVersion.store_listing_version_id}
storeListingVersionId={latestVersion.listing_version_id}
/>
)}
@@ -115,14 +113,17 @@ export function ExpandableRow({
</TableRow>
</TableHeader>
<TableBody>
{listing.versions
.sort((a, b) => (b.version ?? 1) - (a.version ?? 0))
{(listing.versions ?? [])
.sort(
(a, b) =>
(b.listing_version ?? 1) - (a.listing_version ?? 0),
)
.map((version) => (
<TableRow key={version.store_listing_version_id}>
<TableRow key={version.listing_version_id}>
<TableCell>
v{version.version || "?"}
{version.store_listing_version_id ===
listing.active_version_id && (
v{version.listing_version || "?"}
{version.listing_version_id ===
listing.active_listing_version_id && (
<Badge className="ml-2 bg-blue-500">Active</Badge>
)}
</TableCell>
@@ -131,9 +132,9 @@ export function ExpandableRow({
{version.changes_summary || "No summary"}
</TableCell>
<TableCell>
{version.date_submitted
{version.submitted_at
? formatDistanceToNow(
new Date(version.date_submitted),
new Date(version.submitted_at),
{ addSuffix: true },
)
: "Unknown"}
@@ -182,10 +183,10 @@ export function ExpandableRow({
{/* <TableCell>{version.categories.join(", ")}</TableCell> */}
<TableCell className="text-right">
<div className="flex justify-end gap-2">
{version.store_listing_version_id && (
{version.listing_version_id && (
<DownloadAgentAdminButton
storeListingVersionId={
version.store_listing_version_id
version.listing_version_id
}
/>
)}

View File

@@ -12,7 +12,7 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/__legacy__/ui/select";
import { SubmissionStatus } from "@/lib/autogpt-server-api/types";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
export function SearchAndFilterAdminMarketplace({
initialSearch,

View File

@@ -1,11 +1,11 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import type { SubmissionStatus } from "@/lib/autogpt-server-api/types";
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { AdminAgentsDataTable } from "./components/AdminAgentsDataTable";
type MarketplaceAdminPageSearchParams = {
page?: string;
status?: string;
status?: SubmissionStatus;
search?: string;
};
@@ -15,7 +15,7 @@ async function AdminMarketplaceDashboard({
searchParams: MarketplaceAdminPageSearchParams;
}) {
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
const status = searchParams.status as SubmissionStatus | undefined;
const status = searchParams.status;
const search = searchParams.search;
return (

View File

@@ -0,0 +1,80 @@
"use client";
import { Card } from "@/components/atoms/Card/Card";
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
import { useAdminUsersPage } from "../../useAdminUsersPage";
export function AdminUsersPage() {
const {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers,
isLoadingInvitedUsers,
isRefreshingInvitedUsers,
isCreatingInvite,
isBulkInviting,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
} = useAdminUsersPage();
return (
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
<div className="flex flex-col gap-2">
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
<p className="max-w-3xl text-sm text-zinc-600">
Pre-provision beta users before they sign up. Invites store the
platform-side record, run Tally understanding extraction, and activate
the real account on the user&apos;s first authenticated request.
</p>
</div>
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
<div className="flex flex-col gap-6">
<Card className="border border-zinc-200 shadow-sm">
<InviteUserForm
email={email}
name={name}
isSubmitting={isCreatingInvite}
onEmailChange={setEmail}
onNameChange={setName}
onSubmit={handleCreateInvite}
/>
</Card>
<Card className="border border-zinc-200 shadow-sm">
<BulkInviteForm
selectedFile={bulkInviteFile}
inputKey={bulkInviteInputKey}
isSubmitting={isBulkInviting}
lastResult={lastBulkInviteResult}
onFileChange={handleBulkInviteFileChange}
onSubmit={handleBulkInviteSubmit}
/>
</Card>
</div>
<Card className="border border-zinc-200 shadow-sm">
<InvitedUsersTable
invitedUsers={invitedUsers}
isLoading={isLoadingInvitedUsers}
isRefreshing={isRefreshingInvitedUsers}
pendingInviteAction={pendingInviteAction}
onRetryTally={handleRetryTally}
onRevoke={handleRevoke}
/>
</Card>
</div>
</div>
);
}

View File

@@ -0,0 +1,131 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import type { FormEvent } from "react";
interface Props {
selectedFile: File | null;
inputKey: number;
isSubmitting: boolean;
lastResult: BulkInvitedUsersResponse | null;
onFileChange: (file: File | null) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
if (status === "CREATED") {
return "success";
}
if (status === "ERROR") {
return "error";
}
return "info";
}
export function BulkInviteForm({
selectedFile,
inputKey,
isSubmitting,
lastResult,
onFileChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
<p className="text-sm text-zinc-600">
Upload a <span className="font-medium text-zinc-800">.txt</span> file
with one email per line, or a{" "}
<span className="font-medium text-zinc-800">.csv</span> with
<span className="font-medium text-zinc-800"> email</span> and optional
<span className="font-medium text-zinc-800"> name</span> columns.
</p>
</div>
<label className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors hover:border-zinc-400 hover:bg-zinc-100">
<span className="font-medium text-zinc-900">
{selectedFile ? selectedFile.name : "Choose invite file"}
</span>
<span>Maximum 500 rows, UTF-8 encoded.</span>
<input
key={inputKey}
type="file"
accept=".txt,.csv,text/plain,text/csv"
disabled={isSubmitting}
className="hidden"
onChange={(event) =>
onFileChange(event.target.files?.item(0) ?? null)
}
/>
</label>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!selectedFile}
className="w-full"
>
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
</Button>
{lastResult ? (
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
<div className="grid grid-cols-3 gap-2 text-center">
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.created_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Created
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.skipped_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Skipped
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.error_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Errors
</div>
</div>
</div>
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
<div className="flex flex-col divide-y divide-zinc-100">
{lastResult.results.map((row) => (
<div
key={`${row.row_number}-${row.email ?? row.message}`}
className="flex items-start gap-3 px-3 py-3"
>
<Badge variant={getStatusVariant(row.status)} size="small">
{row.status}
</Badge>
<div className="flex min-w-0 flex-1 flex-col gap-1">
<span className="text-sm font-medium text-zinc-900">
Row {row.row_number}
{row.email ? ` · ${row.email}` : ""}
</span>
<span className="text-xs text-zinc-500">{row.message}</span>
</div>
</div>
))}
</div>
</div>
</div>
) : null}
</form>
);
}

View File

@@ -0,0 +1,66 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import type { FormEvent } from "react";
interface Props {
email: string;
name: string;
isSubmitting: boolean;
onEmailChange: (value: string) => void;
onNameChange: (value: string) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
export function InviteUserForm({
email,
name,
isSubmitting,
onEmailChange,
onNameChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
<p className="text-sm text-zinc-600">
The invite is stored immediately, then Tally pre-seeding starts in the
background.
</p>
</div>
<Input
id="invite-email"
label="Email"
type="email"
value={email}
placeholder="jane@example.com"
autoComplete="email"
disabled={isSubmitting}
onChange={(event) => onEmailChange(event.target.value)}
/>
<Input
id="invite-name"
label="Name"
type="text"
value={name}
placeholder="Jane Doe"
disabled={isSubmitting}
onChange={(event) => onNameChange(event.target.value)}
/>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!email.trim()}
className="w-full"
>
{isSubmitting ? "Creating invite..." : "Create invite"}
</Button>
</form>
);
}

View File

@@ -0,0 +1,209 @@
"use client";
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
interface Props {
invitedUsers: InvitedUserResponse[];
isLoading: boolean;
isRefreshing: boolean;
pendingInviteAction: string | null;
onRetryTally: (invitedUserId: string) => void;
onRevoke: (invitedUserId: string) => void;
}
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
if (status === "CLAIMED") {
return "success";
}
if (status === "REVOKED") {
return "error";
}
return "info";
}
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
if (status === "READY") {
return "success";
}
if (status === "FAILED") {
return "error";
}
return "info";
}
function formatDate(value: Date | undefined) {
if (!value) {
return "-";
}
return value.toLocaleString();
}
function getTallySummary(invitedUser: InvitedUserResponse) {
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
return invitedUser.tally_error;
}
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
return "Stored and ready for activation";
}
if (invitedUser.tally_status === "READY") {
return "No matching Tally submission found";
}
if (invitedUser.tally_status === "RUNNING") {
return "Extraction in progress";
}
return "Waiting to run";
}
function isActionPending(
pendingInviteAction: string | null,
action: "retry" | "revoke",
invitedUserId: string,
) {
return pendingInviteAction === `${action}:${invitedUserId}`;
}
export function InvitedUsersTable({
invitedUsers,
isLoading,
isRefreshing,
pendingInviteAction,
onRetryTally,
onRevoke,
}: Props) {
return (
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
<p className="text-sm text-zinc-600">
Live invite state, claim status, and Tally pre-seeding progress.
</p>
</div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
</span>
</div>
<div className="overflow-hidden rounded-2xl border border-zinc-200">
<Table>
<TableHeader className="bg-zinc-50">
<TableRow>
<TableHead>Email</TableHead>
<TableHead>Name</TableHead>
<TableHead>Invite</TableHead>
<TableHead>Tally</TableHead>
<TableHead>Claimed User</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{isLoading ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
Loading invited users...
</TableCell>
</TableRow>
) : invitedUsers.length === 0 ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
No invited users yet
</TableCell>
</TableRow>
) : (
invitedUsers.map((invitedUser) => (
<TableRow key={invitedUser.id} className="align-top">
<TableCell className="font-medium text-zinc-900">
{invitedUser.email}
</TableCell>
<TableCell>{invitedUser.name || "-"}</TableCell>
<TableCell>
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
{invitedUser.status}
</Badge>
</TableCell>
<TableCell>
<div className="flex max-w-xs flex-col gap-2">
<Badge
variant={getTallyBadgeVariant(invitedUser.tally_status)}
>
{invitedUser.tally_status}
</Badge>
<span className="text-xs text-zinc-500">
{getTallySummary(invitedUser)}
</span>
<span className="text-xs text-zinc-400">
{formatDate(invitedUser.tally_computed_at ?? undefined)}
</span>
</div>
</TableCell>
<TableCell className="font-mono text-xs text-zinc-500">
{invitedUser.auth_user_id || "-"}
</TableCell>
<TableCell className="text-sm text-zinc-500">
{formatDate(invitedUser.created_at)}
</TableCell>
<TableCell>
<div className="flex justify-end gap-2">
<Button
variant="outline"
size="small"
disabled={invitedUser.status === "REVOKED"}
loading={isActionPending(
pendingInviteAction,
"retry",
invitedUser.id,
)}
onClick={() => onRetryTally(invitedUser.id)}
>
Retry Tally
</Button>
<Button
variant="secondary"
size="small"
disabled={invitedUser.status !== "INVITED"}
loading={isActionPending(
pendingInviteAction,
"revoke",
invitedUser.id,
)}
onClick={() => onRevoke(invitedUser.id)}
>
Revoke
</Button>
</div>
</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
</div>
</div>
);
}

View File

@@ -1,16 +1,11 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import React from "react";
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
function AdminUsers() {
return (
<div>
<h1>Users Dashboard</h1>
{/* Add your admin-only content here */}
</div>
);
return <AdminUsersPage />;
}
export default async function AdminUsersPage() {
export default async function AdminUsersRoute() {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);

View File

@@ -0,0 +1,201 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { okData } from "@/app/api/helpers";
import {
getGetV2ListInvitedUsersQueryKey,
useGetV2ListInvitedUsers,
usePostV2BulkCreateInvitedUsers,
usePostV2CreateInvitedUser,
usePostV2RetryInvitedUserTally,
usePostV2RevokeInvitedUser,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { ApiError } from "@/lib/autogpt-server-api/helpers";
import { useQueryClient } from "@tanstack/react-query";
import { type FormEvent, useState } from "react";
function getErrorMessage(error: unknown) {
if (error instanceof ApiError) {
return error.message;
}
if (error instanceof Error) {
return error.message;
}
return "Something went wrong";
}
export function useAdminUsersPage() {
const queryClient = useQueryClient();
const { toast } = useToast();
const [email, setEmail] = useState("");
const [name, setName] = useState("");
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
const [lastBulkInviteResult, setLastBulkInviteResult] =
useState<BulkInvitedUsersResponse | null>(null);
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
null,
);
const invitedUsersQuery = useGetV2ListInvitedUsers({
query: {
select: okData,
refetchInterval: 5000,
},
});
const createInvitedUserMutation = usePostV2CreateInvitedUser({
mutation: {
onSuccess: async () => {
setEmail("");
setName("");
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invited user created",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
mutation: {
onSuccess: async (response) => {
const result = okData(response) ?? null;
setBulkInviteFile(null);
setBulkInviteInputKey((currentValue) => currentValue + 1);
setLastBulkInviteResult(result);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: result
? `${result.created_count} invites created`
: "Bulk invite upload complete",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Tally pre-seeding restarted",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invite revoked",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
createInvitedUserMutation.mutate({
data: {
email,
name: name.trim() || null,
},
});
}
function handleRetryTally(invitedUserId: string) {
setPendingInviteAction(`retry:${invitedUserId}`);
retryInvitedUserTallyMutation.mutate({ invitedUserId });
}
function handleBulkInviteFileChange(file: File | null) {
setBulkInviteFile(file);
}
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
if (!bulkInviteFile) {
return;
}
bulkCreateInvitedUsersMutation.mutate({
data: {
file: bulkInviteFile,
},
});
}
function handleRevoke(invitedUserId: string) {
setPendingInviteAction(`revoke:${invitedUserId}`);
revokeInvitedUserMutation.mutate({ invitedUserId });
}
return {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
isCreatingInvite: createInvitedUserMutation.isPending,
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
};
}

View File

@@ -151,6 +151,9 @@ export function ChatInput({
onFilesSelected={handleFilesSelected}
disabled={isBusy}
/>
</PromptInputTools>
<div className="flex items-center gap-4">
{showMicButton && (
<RecordingButton
isRecording={isRecording}
@@ -160,13 +163,12 @@ export function ChatInput({
onClick={toggleRecording}
/>
)}
</PromptInputTools>
{isStreaming ? (
<PromptInputSubmit status="streaming" onStop={onStop} />
) : (
<PromptInputSubmit disabled={!canSend} />
)}
{isStreaming ? (
<PromptInputSubmit status="streaming" onStop={onStop} />
) : (
<PromptInputSubmit disabled={!canSend} />
)}
</div>
</PromptInputFooter>
</InputGroup>
</form>

View File

@@ -28,10 +28,9 @@ export function RecordingButton({
disabled={disabled}
onClick={onClick}
className={cn(
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
"border-0 bg-white text-zinc-500 hover:bg-zinc-50 hover:text-zinc-700",
disabled && "opacity-40",
isRecording &&
"animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600",
isRecording && "animate-pulse bg-red-500 text-white hover:bg-red-600",
isTranscribing && "bg-zinc-100 text-zinc-400",
isStreaming && "opacity-40",
)}

View File

@@ -5,15 +5,19 @@ import {
} from "@/components/ai-elements/conversation";
import { Message, MessageContent } from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { FileUIPart, ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
import { parseSpecialMarkers } from "./helpers";
import { AssistantMessageActions } from "./components/AssistantMessageActions";
import { CollapsedToolGroup } from "./components/CollapsedToolGroup";
import { MessageAttachments } from "./components/MessageAttachments";
import { MessagePartRenderer } from "./components/MessagePartRenderer";
import { ReasoningCollapse } from "./components/ReasoningCollapse";
import { ThinkingIndicator } from "./components/ThinkingIndicator";
type MessagePart = UIMessage<unknown, UIDataTypes, UITools>["parts"][number];
interface Props {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
@@ -23,6 +27,132 @@ interface Props {
sessionID?: string | null;
}
function isCompletedToolPart(part: MessagePart): part is ToolUIPart {
return (
part.type.startsWith("tool-") &&
"state" in part &&
(part.state === "output-available" || part.state === "output-error")
);
}
type RenderSegment =
| { kind: "part"; part: MessagePart; index: number }
| { kind: "collapsed-group"; parts: ToolUIPart[] };
// Tool types that have custom renderers and should NOT be collapsed
const CUSTOM_TOOL_TYPES = new Set([
"tool-find_block",
"tool-find_agent",
"tool-find_library_agent",
"tool-search_docs",
"tool-get_doc_page",
"tool-run_block",
"tool-run_mcp_tool",
"tool-run_agent",
"tool-schedule_agent",
"tool-create_agent",
"tool-edit_agent",
"tool-view_agent_output",
"tool-search_feature_requests",
"tool-create_feature_request",
]);
/**
* Groups consecutive completed generic tool parts into collapsed segments.
* Non-generic tools (those with custom renderers) and active/streaming tools
* are left as individual parts.
*/
function buildRenderSegments(
parts: MessagePart[],
baseIndex = 0,
): RenderSegment[] {
const segments: RenderSegment[] = [];
let pendingGroup: Array<{ part: ToolUIPart; index: number }> | null = null;
function flushGroup() {
if (!pendingGroup) return;
if (pendingGroup.length >= 2) {
segments.push({
kind: "collapsed-group",
parts: pendingGroup.map((p) => p.part),
});
} else {
for (const p of pendingGroup) {
segments.push({ kind: "part", part: p.part, index: p.index });
}
}
pendingGroup = null;
}
parts.forEach((part, i) => {
const absoluteIndex = baseIndex + i;
const isGenericCompletedTool =
isCompletedToolPart(part) && !CUSTOM_TOOL_TYPES.has(part.type);
if (isGenericCompletedTool) {
if (!pendingGroup) pendingGroup = [];
pendingGroup.push({ part: part as ToolUIPart, index: absoluteIndex });
} else {
flushGroup();
segments.push({ kind: "part", part, index: absoluteIndex });
}
});
flushGroup();
return segments;
}
/**
* For finalized assistant messages, split parts into "reasoning" (intermediate
* text + tools before the final response) and "response" (final text after the
* last tool). If there are no tools, everything is response.
*/
function splitReasoningAndResponse(parts: MessagePart[]): {
reasoning: MessagePart[];
response: MessagePart[];
} {
const lastToolIndex = parts.findLastIndex((p) => p.type.startsWith("tool-"));
// No tools → everything is response
if (lastToolIndex === -1) {
return { reasoning: [], response: parts };
}
// Check if there's any text after the last tool
const hasResponseAfterTools = parts
.slice(lastToolIndex + 1)
.some((p) => p.type === "text");
if (!hasResponseAfterTools) {
// No final text response → don't collapse anything
return { reasoning: [], response: parts };
}
return {
reasoning: parts.slice(0, lastToolIndex + 1),
response: parts.slice(lastToolIndex + 1),
};
}
function renderSegments(
segments: RenderSegment[],
messageID: string,
): React.ReactNode[] {
return segments.map((seg, segIdx) => {
if (seg.kind === "collapsed-group") {
return <CollapsedToolGroup key={`group-${segIdx}`} parts={seg.parts} />;
}
return (
<MessagePartRenderer
key={`${messageID}-${seg.index}`}
part={seg.part}
messageID={messageID}
partIndex={seg.index}
/>
);
});
}
/** Collect all messages belonging to a turn: the user message + every
* assistant message up to (but not including) the next user message. */
function getTurnMessages(
@@ -119,6 +249,24 @@ export function ChatMessagesContainer({
(p): p is FileUIPart => p.type === "file",
);
// For finalized assistant messages, split into reasoning + response.
// During streaming, show everything normally with tool collapsing.
const isFinalized =
message.role === "assistant" && !isCurrentlyStreaming;
const { reasoning, response } = isFinalized
? splitReasoningAndResponse(message.parts)
: { reasoning: [] as MessagePart[], response: message.parts };
const hasReasoning = reasoning.length > 0;
const responseStartIndex = message.parts.length - response.length;
const responseSegments =
message.role === "assistant"
? buildRenderSegments(response, responseStartIndex)
: null;
const reasoningSegments = hasReasoning
? buildRenderSegments(reasoning, 0)
: null;
return (
<Message from={message.role} key={message.id}>
<MessageContent
@@ -128,14 +276,21 @@ export function ChatMessagesContainer({
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
}
>
{message.parts.map((part, i) => (
<MessagePartRenderer
key={`${message.id}-${i}`}
part={part}
messageID={message.id}
partIndex={i}
/>
))}
{hasReasoning && reasoningSegments && (
<ReasoningCollapse>
{renderSegments(reasoningSegments, message.id)}
</ReasoningCollapse>
)}
{responseSegments
? renderSegments(responseSegments, message.id)
: message.parts.map((part, i) => (
<MessagePartRenderer
key={`${message.id}-${i}`}
part={part}
messageID={message.id}
partIndex={i}
/>
))}
{isLastInTurn && !isCurrentlyStreaming && (
<TurnStatsBar
turnMessages={getTurnMessages(messages, messageIndex)}

View File

@@ -0,0 +1,152 @@
"use client";
import { useId, useState } from "react";
import {
ArrowsClockwiseIcon,
CaretRightIcon,
CheckCircleIcon,
FileIcon,
FilesIcon,
GearIcon,
GlobeIcon,
ListChecksIcon,
MagnifyingGlassIcon,
MonitorIcon,
PencilSimpleIcon,
TerminalIcon,
TrashIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import {
type ToolCategory,
extractToolName,
getAnimationText,
getToolCategory,
} from "../../../tools/GenericTool/helpers";
interface Props {
parts: ToolUIPart[];
}
/** Category icon matching GenericTool's ToolIcon for completed states. */
function EntryIcon({
category,
isError,
}: {
category: ToolCategory;
isError: boolean;
}) {
if (isError) {
return (
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
);
}
const iconClass = "text-green-500";
switch (category) {
case "bash":
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
case "web":
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
case "browser":
return <MonitorIcon size={14} weight="regular" className={iconClass} />;
case "file-read":
case "file-write":
return <FileIcon size={14} weight="regular" className={iconClass} />;
case "file-delete":
return <TrashIcon size={14} weight="regular" className={iconClass} />;
case "file-list":
return <FilesIcon size={14} weight="regular" className={iconClass} />;
case "search":
return (
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
);
case "edit":
return (
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
);
case "todo":
return (
<ListChecksIcon size={14} weight="regular" className={iconClass} />
);
case "compaction":
return (
<ArrowsClockwiseIcon size={14} weight="regular" className={iconClass} />
);
default:
return <GearIcon size={14} weight="regular" className={iconClass} />;
}
}
export function CollapsedToolGroup({ parts }: Props) {
const [expanded, setExpanded] = useState(false);
const panelId = useId();
const errorCount = parts.filter((p) => p.state === "output-error").length;
const label =
errorCount > 0
? `${parts.length} tool calls (${errorCount} failed)`
: `${parts.length} tool calls completed`;
return (
<div className="py-1">
<button
type="button"
onClick={() => setExpanded(!expanded)}
aria-expanded={expanded}
aria-controls={panelId}
className="flex items-center gap-1.5 text-sm text-muted-foreground transition-colors hover:text-foreground"
>
<CaretRightIcon
size={12}
weight="bold"
className={
"transition-transform duration-150 " + (expanded ? "rotate-90" : "")
}
/>
{errorCount > 0 ? (
<WarningDiamondIcon
size={14}
weight="regular"
className="text-red-500"
/>
) : (
<CheckCircleIcon
size={14}
weight="regular"
className="text-green-500"
/>
)}
<span>{label}</span>
</button>
{expanded && (
<div
id={panelId}
className="ml-5 mt-1 space-y-0.5 border-l border-neutral-200 pl-3"
>
{parts.map((part) => {
const toolName = extractToolName(part);
const category = getToolCategory(toolName);
const text = getAnimationText(part, category);
const isError = part.state === "output-error";
return (
<div
key={part.toolCallId}
className={
"flex items-center gap-1.5 text-xs " +
(isError ? "text-red-500" : "text-muted-foreground")
}
>
<EntryIcon category={category} isError={isError} />
<span>{text}</span>
</div>
);
})}
</div>
)}
</div>
);
}

View File

@@ -15,6 +15,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
const [comment, setComment] = useState("");
function handleSubmit() {
if (!comment.trim()) return;
onSubmit(comment);
setComment("");
}
@@ -36,7 +37,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
>
<Dialog.Content>
<div className="mx-auto w-[95%] space-y-4">
<p className="text-sm text-slate-600">
<p className="text-sm text-muted-foreground">
Your feedback helps us improve. Share details below.
</p>
<Textarea
@@ -48,12 +49,18 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
className="resize-none"
/>
<div className="flex items-center justify-between">
<p className="text-xs text-slate-400">{comment.length}/2000</p>
<p className="text-xs text-muted-foreground">
{comment.length}/2000
</p>
<div className="flex gap-2">
<Button variant="outline" size="sm" onClick={handleClose}>
Cancel
</Button>
<Button size="sm" onClick={handleSubmit}>
<Button
size="sm"
onClick={handleSubmit}
disabled={!comment.trim()}
>
Submit feedback
</Button>
</div>

View File

@@ -10,6 +10,7 @@ import {
SearchFeatureRequestsTool,
} from "../../../tools/FeatureRequests/FeatureRequests";
import { FindAgentsTool } from "../../../tools/FindAgents/FindAgents";
import { FolderTool } from "../../../tools/FolderTool/FolderTool";
import { FindBlocksTool } from "../../../tools/FindBlocks/FindBlocks";
import { GenericTool } from "../../../tools/GenericTool/GenericTool";
import { RunAgentTool } from "../../../tools/RunAgent/RunAgent";
@@ -145,6 +146,13 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
return <SearchFeatureRequestsTool key={key} part={part as ToolUIPart} />;
case "tool-create_feature_request":
return <CreateFeatureRequestTool key={key} part={part as ToolUIPart} />;
case "tool-create_folder":
case "tool-list_folders":
case "tool-update_folder":
case "tool-move_folder":
case "tool-delete_folder":
case "tool-move_agents_to_folder":
return <FolderTool key={key} part={part as ToolUIPart} />;
default:
// Render a generic tool indicator for SDK built-in
// tools (Read, Glob, Grep, etc.) or any unrecognized tool

View File

@@ -0,0 +1,27 @@
"use client";
import { LightbulbIcon } from "@phosphor-icons/react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
interface Props {
children: React.ReactNode;
}
export function ReasoningCollapse({ children }: Props) {
return (
<Dialog title="Reasoning">
<Dialog.Trigger>
<button
type="button"
className="flex items-center gap-1 text-xs text-zinc-500 transition-colors hover:text-zinc-700"
>
<LightbulbIcon size={12} weight="bold" />
<span>Show reasoning</span>
</button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-1">{children}</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -3,6 +3,7 @@ import {
getGetV2ListSessionsQueryKey,
useDeleteV2DeleteSession,
useGetV2ListSessions,
usePatchV2UpdateSessionTitle,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
@@ -17,7 +18,6 @@ import { toast } from "@/components/molecules/Toast/use-toast";
import {
Sidebar,
SidebarContent,
SidebarFooter,
SidebarHeader,
SidebarTrigger,
useSidebar,
@@ -25,8 +25,9 @@ import {
import { cn } from "@/lib/utils";
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query";
import { motion } from "framer-motion";
import { AnimatePresence, motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "../../store";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
@@ -65,6 +66,39 @@ export function ChatSidebar() {
},
});
const [editingSessionId, setEditingSessionId] = useState<string | null>(null);
const [editingTitle, setEditingTitle] = useState("");
const renameInputRef = useRef<HTMLInputElement>(null);
const renameCancelledRef = useRef(false);
const { mutate: renameSession } = usePatchV2UpdateSessionTitle({
mutation: {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
setEditingSessionId(null);
},
onError: (error) => {
toast({
title: "Failed to rename chat",
description:
error instanceof Error ? error.message : "An error occurred",
variant: "destructive",
});
setEditingSessionId(null);
},
},
});
// Auto-focus the rename input when editing starts
useEffect(() => {
if (editingSessionId && renameInputRef.current) {
renameInputRef.current.focus();
renameInputRef.current.select();
}
}, [editingSessionId]);
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
@@ -76,6 +110,26 @@ export function ChatSidebar() {
setSessionId(id);
}
function handleRenameClick(
e: React.MouseEvent,
id: string,
title: string | null | undefined,
) {
e.stopPropagation();
renameCancelledRef.current = false;
setEditingSessionId(id);
setEditingTitle(title || "");
}
function handleRenameSubmit(id: string) {
const trimmed = editingTitle.trim();
if (trimmed) {
renameSession({ sessionId: id, data: { title: trimmed } });
} else {
setEditingSessionId(null);
}
}
function handleDeleteClick(
e: React.MouseEvent,
id: string,
@@ -160,29 +214,42 @@ export function ChatSidebar() {
</motion.div>
</SidebarHeader>
)}
{!isCollapsed && (
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex flex-col gap-3 px-3"
>
<div className="flex items-center justify-between">
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</div>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
className="flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex min-h-[30rem] items-center justify-center py-4">
@@ -203,76 +270,105 @@ export function ChatSidebar() {
: "hover:bg-zinc-50",
)}
>
<button
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
{editingSessionId === session.id ? (
<div className="px-3 py-2.5">
<input
ref={renameInputRef}
type="text"
aria-label="Rename chat"
value={editingTitle}
onChange={(e) => setEditingTitle(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.currentTarget.blur();
} else if (e.key === "Escape") {
renameCancelledRef.current = true;
setEditingSessionId(null);
}
}}
onBlur={() => {
if (renameCancelledRef.current) {
renameCancelledRef.current = false;
return;
}
handleRenameSubmit(session.id);
}}
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
/>
</div>
) : (
<button
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
<AnimatePresence mode="wait" initial={false}>
<motion.span
key={session.title || "untitled"}
initial={{ opacity: 0, y: 4 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -4 }}
transition={{ duration: 0.2 }}
className="block truncate"
>
{session.title || "Untitled chat"}
</motion.span>
</AnimatePresence>
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
onClick={(e) => e.stopPropagation()}
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-4 w-4" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</button>
)}
{editingSessionId !== session.id && (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
onClick={(e) => e.stopPropagation()}
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-4 w-4" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={(e) =>
handleRenameClick(e, session.id, session.title)
}
>
Rename
</DropdownMenuItem>
<DropdownMenuItem
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
)}
</div>
))
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
<DeleteChatDialog

View File

@@ -29,7 +29,6 @@ export function DeleteChatDialog({
}
},
}}
onClose={isDeleting ? undefined : onCancel}
>
<Dialog.Content>
<Text variant="body">

View File

@@ -71,6 +71,17 @@ export function MobileDrawer({
<X width="1rem" height="1rem" />
</Button>
</div>
<div className="mt-2">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
</div>
<div
className={cn(
@@ -120,19 +131,6 @@ export function MobileDrawer({
))
)}
</div>
{currentSessionId && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>

View File

@@ -181,6 +181,14 @@ export function convertChatSessionMessagesToUiMessages(
if (parts.length === 0) return;
// Merge consecutive assistant messages into a single UIMessage
// to avoid split bubbles on page reload.
const prevUI = uiMessages[uiMessages.length - 1];
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
prevUI.parts.push(...parts);
return;
}
uiMessages.push({
id: `${sessionId}-${index}`,
role: msg.role,

View File

@@ -10,7 +10,7 @@ import {
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
export type CreateAgentToolOutput =
| AgentPreviewResponse
@@ -134,7 +134,7 @@ export function ToolIcon({
);
}
if (isStreaming) {
return <OrbitLoader size={24} />;
return <ScaleLoader size={14} />;
}
return <PlusIcon size={14} weight="regular" className="text-neutral-400" />;
}

Some files were not shown because too many files have changed in this diff Show More