Compare commits

..

16 Commits

Author SHA1 Message Date
Swifty
af7e8d19fe feat(platform): add beta invite provisioning 2026-03-09 16:38:23 +01:00
Otto
eadc68f2a5 feat(frontend/copilot): move microphone button to right side of input box (#12320)
Requested by @olivia-1421

Moves the microphone/recording button from the left-side tools group to
the right side, next to the submit button. The left side is now reserved
for the attachment/upload (plus) button only.

**Before:** `[ 📎 🎤 ] .................. [ ➤ ]`
**After:**  `[ 📎 ] .................. [ 🎤 ➤ ]`

---
Co-authored-by: Olivia <olivia-1421@users.noreply.github.com>

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Lluis Agusti <hi@llu.lu>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:37:02 +08:00
Reinier van der Leer
eca7b5e793 Merge commit from fork 2026-03-08 10:24:44 +01:00
Otto
c304a4937a fix(backend): Handle manual run attempts for triggered agents (#12298)
When a webhook-triggered agent is executed directly (e.g. via Copilot)
without actual webhook data, `GraphExecution.from_db()` crashes with
`KeyError: 'payload'` because it does a hard key access on
`exec.input_data["payload"]` for webhook blocks.

This caused 232 Sentry events (AUTOGPT-SERVER-821) and multiple
INCOMPLETE graph executions due to retries.

**Changes:**

1. **Defensive fix in `from_db()`** — use `.get("payload")` instead of
`["payload"]` to handle missing keys gracefully (matches existing
pattern for input blocks using `.get("value")`)

2. **Upfront refusal in `_construct_starting_node_execution_input()`** —
refuse execution of webhook/webhook_manual blocks when no payload is
provided. The check is placed after `nodes_input_masks` application, so
legitimate webhook triggers (which inject payload via
`nodes_input_masks`) pass through fine.

Resolves [SENTRY-1113: Copilot is able to manually initiate runs for
triggered agents (which
fails)](https://linear.app/autogpt/issue/SENTRY-1113/copilot-is-able-to-manually-initiate-runs-for-triggered-agents-which)

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-06 20:47:51 +00:00
Zamil Majdy
8cfabcf4fd refactor(backend/copilot): centralize prompt building in prompting.py (#12324)
## Summary

Centralizes all prompt building logic into a new
`backend/copilot/prompting.py` module with clear SDK vs baseline and
local vs E2B distinctions.

### Key Changes

**New `prompting.py` module:**
- `get_sdk_supplement(use_e2b, cwd)` - For SDK mode (NO tool docs -
Claude gets schemas automatically)
- `get_baseline_supplement(use_e2b, cwd)` - For baseline mode (WITH
auto-generated tool docs from TOOL_REGISTRY)
- Handles local/E2B storage differences

**SDK mode (`sdk/service.py`):**
- Removed 165+ lines of duplicate constants
- Now imports and uses `get_sdk_supplement()`
- Cleaner, more maintainable

**Baseline mode (`baseline/service.py`):**
- Now appends `get_baseline_supplement()` to system prompt
- Baseline mode finally gets tool documentation!

**Enhanced tool descriptions:**
- `create_agent`: Added feedback loop workflow (suggested_goal,
clarifying_questions)
- `run_mcp_tool`: Added known server URLs, 2-step workflow, auth
handling

**Tests:**
- Updated to verify SDK excludes tool docs, baseline includes them
- All existing tests pass

### Architecture Benefits

 Single source of truth for prompt supplements
 Clear SDK vs baseline distinction (SDK doesn't need tool docs)
 Clear local vs E2B distinction (storage systems)
 Easy to maintain and update
 Eliminates code duplication

## Test plan

- [x] Unit tests pass (TestPromptSupplement class)
- [x] SDK mode excludes tool documentation
- [x] Baseline mode includes tool documentation
- [x] E2B vs local mode differences handled correctly
2026-03-06 18:56:20 +00:00
Zamil Majdy
7bf407b66c Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-07 02:01:41 +07:00
Zamil Majdy
7ead4c040f hotfix(backend/copilot): capture tool results in transcript (#12323)
## Summary
- Fixes tool results not being captured in the CoPilot transcript during
SDK-based streaming
- Adds `transcript_builder.add_user_message()` call with `tool_result`
content block when a `StreamToolOutputAvailable` event is received
- Ensures transcript accurately reflects the full conversation including
tool outputs, which is critical for Langfuse tracing and debugging

## Context
After the transcript refactor in #12318, tool call results from the SDK
streaming loop were not being recorded in the transcript. This meant
Langfuse traces were missing tool outputs, making it hard to debug agent
behavior.

## Test plan
- [ ] Verify CoPilot conversation with tool calls captures tool results
in Langfuse traces
- [ ] Verify transcript includes tool_result content blocks after tool
execution
2026-03-06 18:58:48 +00:00
Abhimanyu Yadav
0f813f1bf9 feat(copilot): Add folder management tools to CoPilot (#12290)
Adds folder management capabilities to the CoPilot, allowing users to
organize agents into folders directly from the chat interface.

<img width="823" height="356" alt="Screenshot 2026-03-05 at 5 26 30 PM"
src="https://github.com/user-attachments/assets/4c55f926-1e71-488f-9eb6-fca87c4ab01b"
/>
<img width="797" height="150" alt="Screenshot 2026-03-05 at 5 28 40 PM"
src="https://github.com/user-attachments/assets/5c9c6f8b-57ac-4122-b17d-b9f091bb7c4e"
/>
<img width="763" height="196" alt="Screenshot 2026-03-05 at 5 28 36 PM"
src="https://github.com/user-attachments/assets/d1b22b5d-921d-44ac-90e8-a5820bb3146d"
/>
<img width="756" height="199" alt="Screenshot 2026-03-05 at 5 30 17 PM"
src="https://github.com/user-attachments/assets/40a59748-f42e-4521-bae0-cc786918a9b5"
/>

### Changes

**Backend -- 6 new CoPilot tools** (`manage_folders.py`):
- `create_folder` -- Create folders with optional parent, icon, and
color
- `list_folders` -- List folder tree or children of a specific folder,
with optional `include_agents` to show agents inside each folder
- `update_folder` -- Rename or change icon/color
- `move_folder` -- Reparent a folder or move to root
- `delete_folder` -- Soft-delete (agents moved to root, not deleted)
- `move_agents_to_folder` -- Bulk-move agents into a folder or back to
root

**Backend -- DatabaseManager RPC registration**:
- Registered all 7 folder DB functions (`create_folder`, `list_folders`,
`get_folder_tree`, `update_folder`, `move_folder`, `delete_folder`,
`bulk_move_agents_to_folder`) in `DatabaseManager` and
`DatabaseManagerAsyncClient` so they work via RPC in the CoPilotExecutor
process
- `manage_folders.py` uses `db_accessors.library_db()` pattern
(consistent with all other copilot tools) instead of direct Prisma
imports

**Backend -- folder_id threading**:
- `create_agent` and `customize_agent` tools accept optional `folder_id`
to save agents directly into a folder
- `save_agent_to_library` -> `create_graph_in_library` ->
`create_library_agent` pipeline passes `folder_id` through
- `create_library_agent` refactored from `asyncio.gather` to sequential
loop to support conditional `folderId` assignment on the main graph only
(not sub-graphs)

**Backend -- system prompt and models**:
- Added folder tool descriptions and usage guidance to Otto's system
prompt
- Added `FolderAgentSummary` model for lightweight agent info in folder
listings
- Added 6 `ResponseType` enum values and corresponding Pydantic response
models (`FolderInfo`, `FolderTreeInfo`, `FolderCreatedResponse`, etc.)

**Frontend -- FolderTool UI component**:
- `FolderTool.tsx` -- Renders folder operations in chat using the
`file-tree` molecule component for tree view, with `FileIcon` for agents
and `FolderIcon` for folders (both `text-neutral-600`)
- `helpers.ts` -- Type guards, output parsing, animation text helpers,
and `FolderAgentSummary` type
- `MessagePartRenderer.tsx` -- Routes 6 folder tool types to
`FolderTool` component
- Flat folder list view shows agents inside `FolderCard` when
`include_agents` is set

**Frontend -- file-tree molecule**:
- Fixed 3 pre-existing lint errors in `file-tree.tsx` (unused `ref`,
`handleSelect`, `className` params)
- Updated tree indicator line color from `bg-neutral-100` to
`bg-neutral-400` for visibility
- Added `file-tree.stories.tsx` with 5 stories: Default, AllExpanded,
FoldersOnly, WithInitialSelection, NoIndicator
- Added `ui/scroll-area.tsx` (dependency of file-tree, was missing from
non-legacy ui folder)

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a folder via copilot chat ("create a folder called
Marketing")
  - [x] List folders ("show me my folders")
- [x] List folders with agents ("show me my folders and the agents in
them")
- [x] Update folder name/icon/color ("rename Marketing folder to Sales")
- [x] Move folder to a different parent ("move Sales into the Projects
folder")
  - [x] Delete a folder and verify agents move to root
- [x] Move agents into a folder ("put my newsletter agent in the
Marketing folder")
- [x] Create agent with folder_id ("create a scraper agent and save it
in my Tools folder")
- [x] Verify FolderTool UI renders loading, success, error, and empty
states correctly
- [x] Verify folder tree renders nested folders with file-tree component
- [x] Verify agents appear as FileIcon nodes in tree view when
include_agents is true
  - [x] Verify file-tree storybook stories render correctly
2026-03-06 14:59:03 +00:00
Reinier van der Leer
aa08063939 refactor(backend/db): Improve & clean up Marketplace DB layer & API (#12284)
These changes were part of #12206, but here they are separately for
easier review.
This is all primarily to make the v2 API (#11678) work possible/easier.

### Changes 🏗️

- Fix relations between `Profile`, `StoreListing`, and `AgentGraph`
- Redefine `StoreSubmission` view with more efficient joins (100x
speed-up on dev DB) and more consistent field names
- Clean up query functions in `store/db.py`
- Clean up models in `store/model.py`
- Add missing fields to `StoreAgent` and `StoreSubmission` views
- Rename ambiguous `agent_id` -> `graph_id`
- Clean up API route definitions & docs in `store/routes.py`
  - Make routes more consistent
- Avoid collision edge-case between `/agents/{username}/{agent_name}`
and `/agents/{store_listing_version_id}/*`
- Replace all usages of legacy `BackendAPI` for store endpoints with
generated client
- Remove scope requirements on public store endpoints in v1 external API

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test all Marketplace views (including admin views)
    - [x] Download an agent from the marketplace
  - [x] Submit an agent to the Marketplace
  - [x] Approve/reject Marketplace submission
2026-03-06 14:38:12 +00:00
Zamil Majdy
bde6a4c0df Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev
# Conflicts:
#	autogpt_platform/backend/backend/copilot/sdk/service.py
2026-03-06 21:07:37 +07:00
Ubbe
7507240177 feat(copilot): collapse repeated tool calls and fix stream stuck on completion (#12282)
## Summary
- **Frontend:** Group consecutive completed generic tool parts into
collapsible summary rows with a "Reasoning" collapse for finalized
messages. Merge consecutive assistant messages on hydration to avoid
split bubbles. Extract GenericTool helpers. Add `reconnectExhausted`
state and a brief delay before refetching session to reduce stale
`active_stream` reconnect cycles.
- **Backend:** Make transcript upload fire-and-forget instead of
blocking the generator exit. The 30s upload timeout in
`_try_upload_transcript` was delaying `mark_session_completed()`,
keeping the SSE stream alive with only heartbeats after the LLM had
finished — causing the UI to stay stuck in "streaming" state.

## Test plan
- [ ] Send a message in Copilot that triggers multiple tool calls —
verify they collapse into a grouped summary row once completed
- [ ] Verify the final text response appears below the collapsed
reasoning section
- [ ] Confirm the stream properly closes after the agent finishes (no
stuck "Stop" button)
- [ ] Refresh mid-stream and verify reconnection works correctly
- [ ] Click Stop during streaming — verify the UI becomes responsive
immediately

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 21:21:59 +08:00
Abhimanyu Yadav
d7c3f5b8fc fix(frontend): bypass Next.js proxy for file uploads to fix 413 error (#12315)
## Summary
- File uploads routed through the Next.js API proxy (`/api/proxy/...`)
fail with HTTP 413 for files >4.5MB due to Vercel's serverless function
body size limit
- Created shared `uploadFileDirect` utility (`src/lib/direct-upload.ts`)
that uploads files directly from the browser to the Python backend,
bypassing the proxy entirely
- Updated `useWorkspaceUpload` to use direct upload instead of the
generated hook (which went through the proxy)
- Deduplicated the copilot page's inline upload logic to use the same
shared utility

## Changes 🏗️
- **New**: `src/lib/direct-upload.ts` — shared utility for
direct-to-backend file uploads (up to 256MB)
- **Updated**: `useWorkspaceUpload.ts` — replaced proxy-based generated
hook with `uploadFileDirect`
- **Updated**: `useCopilotPage.ts` — replaced inline upload logic with
shared `uploadFileDirect`, removed unused imports

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Upload a file >5MB via workspace file input (e.g. in agent
builder) — should succeed without 413
  - [x] Upload a file >5MB via copilot chat — should succeed without 413
  - [x] Upload a small file (<1MB) via both paths — should still work
  - [x] Verify file delete still works from workspace file input
2026-03-06 12:20:18 +00:00
Krzysztof Czerwinski
08c49a78f8 feat(copilot): UX improvements (#12258)
CoPilot conversation UX improvements (SECRT-2055):

1. **Rename conversations** — Inline rename via the session dropdown
menu. New `PATCH /sessions/{session_id}/title` endpoint with server-side
validation (rejects blank/whitespace-only titles, normalizes
whitespace). Pressing Enter or clicking away submits; Escape cancels
without submitting.

2. **New Chat button moved to top & sticky** — The 'New Chat' button is
now at the top of the sidebar (under 'Your chats') instead of the
footer, and stays fixed — only the session list below it scrolls. A
subtle shadow separator mirrors the original footer style.

3. **Auto-generated title appears live** — After the first message in a
new chat, the sidebar polls for the backend-generated title and animates
it in smoothly once available. The backend also guards against
auto-title overwriting a user-set title.

4. **External Link popup redesign** — Replaced the CSS-hacked external
link confirmation dialog with a proper AutoGPT `Dialog` component using
the design system (`Button`, `Text`, `Dialog`). Removed the old
`globals.css` workaround.

<img width="321" height="263" alt="Screenshot 2026-03-03 at 6 31 50 pm"
src="https://github.com/user-attachments/assets/3cdd1c6f-cca6-4f16-8165-15a1dc2d53f7"
/>

<img width="374" height="74" alt="Screenshot 2026-03-02 at 6 39 07 pm"
src="https://github.com/user-attachments/assets/6f9fc953-5fa7-4469-9eab-7074e7604519"
/>

<img width="548" height="293" alt="Screenshot 2026-03-02 at 6 36 28 pm"
src="https://github.com/user-attachments/assets/0f34683b-7281-4826-ac6f-ac7926e67854"
/>

### Changes 🏗️

**Backend:**
- `routes.py`: Added `PATCH /sessions/{session_id}/title` endpoint with
`UpdateSessionTitleRequest` Pydantic model — validates non-blank title,
normalizes whitespace, returns 404 vs 500 correctly
- `routes_test.py`: New test file — 7 test cases covering success,
whitespace trimming, blank rejection (422), not found (404), internal
failure (500)
- `service.py`: Auto-title generation now checks if a user-set title
already exists before overwriting
- `openapi.json`: Updated with new endpoint schema

**Frontend:**
- `ChatSidebar.tsx`: Inline rename (Enter/blur submits, Escape cancels
via ref flag); "New Chat" button sticky at top with shadow separator;
session title animates when auto-generated title appears
(`AnimatePresence`)
- `useCopilotPage.ts`: Polls for auto-generated title after stream ends,
stops as soon as title appears in cache
- `MobileDrawer.tsx`: Updated to match sidebar layout changes
- `DeleteChatDialog.tsx`: Removed redundant `onClose` prop (controlled
Dialog already handles close)
- `message.tsx`: Added `ExternalLinkModal` using AutoGPT design system;
removed redundant `onClose` prop
- `globals.css`: Removed old CSS hack for external link modal

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a new chat, send a message — verify auto-generated title
appears in sidebar without refresh
- [x] Rename a chat via dropdown — Enter submits, Escape reverts, blank
title rejected
- [x] Rename a chat, then send another message — verify user title is
not overwritten by auto-title
- [x] With many chats, scroll the sidebar — verify "New Chat" button
stays fixed at top
- [x] Click an external link in a message — verify the new dialog
appears with AutoGPT styling

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 06:01:41 +00:00
Bently
5d56548e6b fix(frontend): prevent crash on /library with 401 error from pagination helper (#12292)
## Changes
Fixes crash on `/library` page when backend returns a 401 authentication
error.

### Problem

When the backend returns a 401 error, React Query still calls
`getNextPageParam` with the error response. The response doesn't have
the expected pagination structure, causing `pagination` to be
`undefined`. The code then crashes trying
 to access `pagination.current_page`.

Error:
TypeError: Cannot read properties of undefined (reading 'current_page')
    at Object.getNextPageParam

### Solution

Added a defensive null check in `getPaginationNextPageNumber()` to
handle cases where `pagination` is undefined:

```typescript
const { pagination } = lastPage.data;
if (!pagination) return undefined;
```
When undefined is returned, React Query interprets this as "no next page
available" and gracefully stops pagination instead of crashing.

Testing

- Manual testing: Verify /library page handles 401 errors without
crashing
- The fix is defensive and doesn't change behavior for successful
responses

Related Issues

Closes OPEN-2684
2026-03-05 19:52:36 +00:00
Otto
6ecf55d214 fix(frontend): fix 'Open link' button text color to white for contrast (#12304)
Requested by @ntindle

The Streamdown external link safety modal's "Open link" button had dark
text (`color: black`) on a dark background, making it unreadable.
Changed to `color: white` for proper contrast per our design system.

**File:** `autogpt_platform/frontend/src/app/globals.css`

Resolves SECRT-2061

---
Co-authored-by: Nick Tindle (@ntindle)
2026-03-05 19:50:39 +00:00
Bently
7c8c7bf395 feat(llm): add Claude Sonnet 4.6 model (#12158)
## Summary
Adds Claude Sonnet 4.6 (`claude-sonnet-4-6`) to the platform.

## Model Details (from [Anthropic
docs](https://www.anthropic.com/news/claude-sonnet-4-6))
- **API ID:** `claude-sonnet-4-6`
- **Pricing:** $3 / input MTok, $15 / output MTok (same as Sonnet 4.5)
- **Context window:** 200K tokens (1M beta)
- **Max output:** 64K tokens
- **Knowledge cutoff:** Aug 2025 (reliable), Jan 2026 (training data)

## Changes
- Added `CLAUDE_4_6_SONNET` to `LlmModel` enum
- Added metadata entry with correct context/output limits
- Updated Stagehand to use Sonnet 4.6 (better for browser automation
tasks)

## Why
Sonnet 4.6 brings major improvements in coding, computer use, and
reasoning. Developers with early access often prefer it to even Opus
4.5.

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-03-05 19:36:56 +00:00
151 changed files with 9510 additions and 4042 deletions

View File

@@ -1,7 +1,7 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Literal, Optional, Sequence
from typing import Annotated, Any, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.blocks
from backend.api.external.middleware import require_permission
from backend.api.external.middleware import require_auth, require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
sorted_by: store_db.StoreAgentsSortOptions | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -278,7 +279,7 @@ async def get_store_agents(
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
@@ -306,13 +307,13 @@ async def get_store_agent(
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
@@ -348,7 +349,7 @@ async def get_store_creators(
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.CreatorDetails,
)
async def get_store_creator(

View File

@@ -1,4 +1,8 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from backend.data.model import UserTransaction
from backend.util.models import Pagination
@@ -14,3 +18,42 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class CreateInvitedUserRequest(BaseModel):
email: EmailStr
name: Optional[str] = None
class InvitedUserResponse(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
class BulkInvitedUserRowResponse(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]

View File

@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=store_model.StoreListingsWithVersionsResponse,
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
):
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
"""
Get store listings with their version history for admins.
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
page_size: Number of items per page
Returns:
StoreListingsWithVersionsResponse with listings and their versions
Paginated listings with their versions
"""
try:
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=store_model.StoreSubmission,
)
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
) -> store_model.StoreSubmissionAdminView:
"""
Review a store listing submission.
@@ -84,31 +73,24 @@ async def review_submission(
user_id: Authenticated admin user performing the review
Returns:
StoreSubmission with updated review information
StoreSubmissionAdminView with updated review information
"""
try:
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
state_changed = already_approved != request.is_approved
# Clear caches whenever approval state changes, since store visibility can change
if state_changed:
store_cache.clear_all_caches()
return submission
@router.get(

View File

@@ -0,0 +1,143 @@
import logging
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Security, UploadFile
from backend.data.invited_user import (
BulkInvitedUsersResult,
InvitedUserRecord,
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from .model import (
BulkInvitedUserRowResponse,
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["users", "admin"],
dependencies=[Security(requires_admin_user)],
)
def _to_response(invited_user: InvitedUserRecord) -> InvitedUserResponse:
return InvitedUserResponse(**invited_user.model_dump())
def _to_bulk_response(result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return BulkInvitedUsersResponse(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
_to_response(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
summary="List Invited Users",
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users = await list_invited_users()
return InvitedUsersResponse(
invited_users=[_to_response(invited_user) for invited_user in invited_users]
)
@router.post(
"/invited-users",
response_model=InvitedUserResponse,
summary="Create Invited User",
)
async def create_invited_user_route(
request: CreateInvitedUserRequest,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s created invited user for %s",
admin_user_id,
request.email,
)
invited_user = await create_invited_user(request.email, request.name)
return _to_response(invited_user)
@router.post(
"/invited-users/bulk",
response_model=BulkInvitedUsersResponse,
summary="Bulk Create Invited Users",
operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
file: UploadFile = File(...),
admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
logger.info(
"Admin user %s bulk invited users from %s",
admin_user_id,
file.filename or "<unnamed>",
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return _to_bulk_response(result)
@router.post(
"/invited-users/{invited_user_id}/revoke",
response_model=InvitedUserResponse,
summary="Revoke Invited User",
)
async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
invited_user = await revoke_invited_user(invited_user_id)
return _to_response(invited_user)
@router.post(
"/invited-users/{invited_user_id}/retry-tally",
response_model=InvitedUserResponse,
summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
return _to_response(invited_user)

View File

@@ -0,0 +1,165 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
now = datetime.now(timezone.utc)
return InvitedUserRecord(
id="invite-1",
email="invited@example.com",
status=prisma.enums.InvitedUserStatus.INVITED,
auth_user_id=None,
name="Invited User",
tally_understanding=None,
tally_status=prisma.enums.TallyComputationStatus.PENDING,
tally_computed_at=None,
tally_error=None,
created_at=now,
updated_at=now,
)
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
return BulkInvitedUsersResult(
created_count=1,
skipped_count=1,
error_count=0,
results=[
BulkInvitedUserRowResult(
row_number=1,
email="invited@example.com",
name=None,
status="CREATED",
message="Invite created",
invited_user=_sample_invited_user(),
),
BulkInvitedUserRowResult(
row_number=2,
email="duplicate@example.com",
name=None,
status="SKIPPED",
message="An invited user with this email already exists",
invited_user=None,
),
],
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=[_sample_invited_user()]),
)
response = client.get("/admin/invited-users")
assert response.status_code == 200
data = response.json()
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
def test_create_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.create_invited_user",
AsyncMock(return_value=_sample_invited_user()),
)
response = client.post(
"/admin/invited-users",
json={"email": "invited@example.com", "name": "Invited User"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "invited@example.com"
assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
AsyncMock(return_value=_sample_bulk_invited_users_result()),
)
response = client.post(
"/admin/invited-users/bulk",
files={
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
},
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert data["skipped_count"] == 1
assert data["results"][0]["status"] == "CREATED"
assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
revoked = _sample_invited_user().model_copy(
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
AsyncMock(return_value=revoked),
)
response = client.post("/admin/invited-users/invite-1/revoke")
assert response.status_code == 200
assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
mocker: pytest_mock.MockerFixture,
) -> None:
retried = _sample_invited_user().model_copy(
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
AsyncMock(return_value=retried),
)
response = client.post("/admin/invited-users/invite-1/retry-tally")
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"

View File

@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -25,6 +25,7 @@ from backend.copilot.model import (
delete_chat_session,
get_chat_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
@@ -141,6 +142,20 @@ class CancelSessionResponse(BaseModel):
reason: str | None = None
class UpdateSessionTitleRequest(BaseModel):
"""Request model for updating a session's title."""
title: str
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v: str) -> str:
stripped = v.strip()
if not stripped:
raise ValueError("Title must not be blank")
return stripped
# ========== Routes ==========
@@ -264,6 +279,43 @@ async def delete_session(
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
dependencies=[Security(auth.requires_user)],
status_code=200,
responses={404: {"description": "Session not found or access denied"}},
)
async def update_session_title_route(
session_id: str,
request: UpdateSessionTitleRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
"""
Update the title of a chat session.
Allows the user to rename their chat session.
Args:
session_id: The session ID to update.
request: Request body containing the new title.
user_id: The authenticated user's ID.
Returns:
dict: Status of the update.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
success = await update_session_title(session_id, user_id, request.title)
if not success:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return {"status": "ok"}
@router.get(
"/sessions/{session_id}",
)
@@ -753,7 +805,6 @@ async def resume_session_stream(
@router.patch(
"/sessions/{session_id}/assign-user",
dependencies=[Security(auth.requires_user)],
status_code=200,
)
async def session_assign_user(
session_id: str,

View File

@@ -1,4 +1,6 @@
"""Tests for chat route file_ids validation and enrichment."""
"""Tests for chat API routes: session title update and file attachment validation."""
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -17,6 +19,7 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
@@ -24,7 +27,95 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
# ---- file_ids Pydantic validation (B1) ----
def _mock_update_session_title(
mocker: pytest_mock.MockerFixture, *, success: bool = True
):
"""Mock update_session_title."""
return mocker.patch(
"backend.api.features.chat.routes.update_session_title",
new_callable=AsyncMock,
return_value=success,
)
# ─── Update title: success ─────────────────────────────────────────────
def test_update_title_success(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "My project"},
)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
def test_update_title_trims_whitespace(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": " trimmed "},
)
assert response.status_code == 200
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
def test_update_title_blank_rejected(
test_user_id: str,
) -> None:
"""Whitespace-only titles must be rejected before hitting the DB."""
response = client.patch(
"/sessions/sess-1/title",
json={"title": " "},
)
assert response.status_code == 422
def test_update_title_empty_rejected(
test_user_id: str,
) -> None:
response = client.patch(
"/sessions/sess-1/title",
json={"title": ""},
)
assert response.status_code == 422
# ─── Update title: session not found or wrong user → 404 ──────────────
def test_update_title_not_found(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
_mock_update_session_title(mocker, success=False)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "New name"},
)
assert response.status_code == 404
# ─── file_ids Pydantic validation ─────────────────────────────────────
def test_stream_chat_rejects_too_many_file_ids():
@@ -92,7 +183,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
# ---- UUID format filtering ----
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
@@ -131,7 +222,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ---- Cross-workspace file_ids ----
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):

View File

@@ -8,7 +8,6 @@ import prisma.errors
import prisma.models
import prisma.types
import backend.api.features.store.exceptions as store_exceptions
import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
The requested LibraryAgent.
Raises:
AgentNotFoundError: If the specified agent does not exist.
NotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during retrieval.
"""
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
@@ -398,6 +397,7 @@ async def create_library_agent(
hitl_safe_mode: bool = True,
sensitive_action_safe_mode: bool = False,
create_library_agents_for_sub_graphs: bool = True,
folder_id: str | None = None,
) -> list[library_model.LibraryAgent]:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -414,12 +414,18 @@ async def create_library_agent(
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
Raises:
AgentNotFoundError: If the specified agent does not exist.
NotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during creation or if image generation fails.
"""
logger.info(
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
)
# Authorization: FK only checks existence, not ownership.
# Verify the folder belongs to this user to prevent cross-user nesting.
if folder_id:
await get_folder(folder_id, user_id)
graph_entries = (
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
)
@@ -432,7 +438,6 @@ async def create_library_agent(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
@@ -448,6 +453,11 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
folder_id: str | None = None,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
@@ -542,6 +553,7 @@ async def create_graph_in_library(
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
folder_id=folder_id,
)
if created_graph.is_active:
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
AgentNotFoundError: If the store listing or associated agent is not found.
NotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
)
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
include_subgraphs=False,
)
if not graph_model:
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(
nodes: list[library_model.LibraryFolderTree],
visited: set[str] | None = None,
) -> list[str]:
"""Collect all folder IDs from a folder tree."""
if visited is None:
visited = set()
ids: list[str] = []
for n in nodes:
if n.id in visited:
continue
visited.add(n.id)
ids.append(n.id)
ids.extend(collect_tree_ids(n.children, visited))
return ids
async def get_folder_agent_summaries(
user_id: str, folder_id: str
) -> list[dict[str, str | None]]:
"""Get a lightweight list of agents in a folder (id, name, description)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, folder_id=folder_id, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_root_agent_summaries(
user_id: str,
) -> list[dict[str, str | None]]:
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, include_root_only=True, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_folder_agents_map(
user_id: str, folder_ids: list[str]
) -> dict[str, list[dict[str, str | None]]]:
"""Get agent summaries for multiple folders concurrently."""
results = await asyncio.gather(
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
)
return dict(zip(folder_ids, results))
##############################################
########### Presets DB Functions #############
##############################################

View File

@@ -4,7 +4,6 @@ import prisma.enums
import prisma.models
import pytest
import backend.api.features.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
)
# Call function and verify exception
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
with pytest.raises(db.NotFoundError):
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly

View File

@@ -1,5 +1,3 @@
from typing import Literal
from backend.util.cache import cached
from . import db as store_db
@@ -23,7 +21,7 @@ def clear_all_caches():
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
sorted_by: store_db.StoreAgentsSortOptions | None,
search_query: str | None,
category: str | None,
page: int,
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
sorted_by: store_db.StoreCreatorsSortOptions | None,
page: int,
page_size: int,
):
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await store_db.get_store_creator_details(username=username.lower())
return await store_db.get_store_creator(username=username.lower())

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
mock_agents = [
prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video=None,
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
graph_id="test-graph-id",
graph_versions=["1"],
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
use_for_onboarding=False,
)
]
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agent_details(mocker):
# Mock data
# Mock data - StoreAgent view already contains the active version data
mock_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
)
# Mock active version agent (what we want to return for active version)
mock_active_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="active-version-id",
slug="test-agent",
agent_name="Test Agent Active",
agent_video="active_video.mp4",
agent_image=["active_image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading active",
description="Test description active",
categories=["test"],
runs=15,
rating=4.8,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id-active",
graph_id="test-graph-id",
graph_versions=["1"],
updated_at=datetime.now(),
is_available=True,
useForOnboarding=False,
use_for_onboarding=False,
)
# Create a mock StoreListing result
mock_store_listing = mocker.MagicMock()
mock_store_listing.activeVersionId = "active-version-id"
mock_store_listing.hasApprovedVersion = True
mock_store_listing.ActiveVersion = mocker.MagicMock()
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
# Mock StoreAgent prisma call - need to handle multiple calls
# Mock StoreAgent prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
# Set up side_effect to return different results for different calls
def mock_find_first_side_effect(*args, **kwargs):
where_clause = kwargs.get("where", {})
if "storeListingVersionId" in where_clause:
# Second call for active version
return mock_active_agent
else:
# First call for initial lookup
return mock_agent
mock_store_agent.return_value.find_first = mocker.AsyncMock(
side_effect=mock_find_first_side_effect
)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_store_listing
)
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results - should use active version data
# Verify results - constructed from the StoreAgent view
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent Active" # From active version
assert result.active_version_id == "active-version-id"
assert result.agent_name == "Test Agent"
assert result.active_version_id == "version123"
assert result.has_approved_version is True
assert (
result.store_listing_version_id == "active-version-id"
) # Should be active version ID
assert result.store_listing_version_id == "version123"
assert result.graph_id == "test-graph-id"
assert result.runs == 10
assert result.rating == 4.5
# Verify mocks called correctly - now expecting 2 calls
assert mock_store_agent.return_value.find_first.call_count == 2
# Check the specific calls
calls = mock_store_agent.return_value.find_first.call_args_list
assert calls[0] == mocker.call(
# Verify single StoreAgent lookup
mock_store_agent.return_value.find_first.assert_called_once_with(
where={"creator_username": "creator", "slug": "test-agent"}
)
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creator_details(mocker):
async def test_get_store_creator(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
name="Test Creator",
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
mock_creator.return_value.find_unique.return_value = mock_creator_data
# Call function
result = await db.get_store_creator_details("creator")
result = await db.get_store_creator("creator")
# Verify results
assert result.username == "creator"
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_create_store_submission(mocker):
# Mock data
now = datetime.now()
# Mock agent graph (with no pending submissions) and user with profile
mock_profile = prisma.models.Profile(
id="profile-id",
userId="user-id",
name="Test User",
username="testuser",
description="Test",
isFeatured=False,
links=[],
createdAt=now,
updatedAt=now,
)
mock_user = prisma.models.User(
id="user-id",
email="test@example.com",
createdAt=now,
updatedAt=now,
Profile=[mock_profile],
emailVerified=True,
metadata="{}", # type: ignore[reportArgumentType]
integrations="",
maxEmailsPerDay=1,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",
version=1,
userId="user-id",
createdAt=datetime.now(),
createdAt=now,
isActive=True,
StoreListingVersions=[],
User=mock_user,
)
mock_listing = prisma.models.StoreListing(
# Mock the created StoreListingVersion (returned by create)
mock_store_listing_obj = prisma.models.StoreListing(
id="listing-id",
createdAt=datetime.now(),
updatedAt=datetime.now(),
createdAt=now,
updatedAt=now,
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentGraphId="agent-id",
agentGraphVersion=1,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),
updatedAt=datetime.now(),
subHeading="Test heading",
imageUrls=["image.jpg"],
categories=["test"],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
)
],
useForOnboarding=False,
)
mock_version = prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=now,
updatedAt=now,
subHeading="",
imageUrls=[],
categories=[],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
submittedAt=now,
StoreListing=mock_store_listing_obj,
)
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
# Call function
result = await db.create_store_submission(
user_id="user-id",
agent_id="agent-id",
agent_version=1,
graph_id="agent-id",
graph_version=1,
slug="test-agent",
name="Test Agent",
description="Test description",
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
# Verify results
assert result.name == "Test Agent"
assert result.description == "Test description"
assert result.store_listing_version_id == "version-id"
assert result.listing_version_id == "version-id"
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()
mock_slv.return_value.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
is_featured=False,
)
# Call function
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
creators=["creator1'; DROP TABLE Users; --", "creator2"],
category="AI'; DELETE FROM StoreAgent; --",
featured=True,
sorted_by="rating",
sorted_by=db.StoreAgentsSortOptions.RATING,
page=1,
page_size=20,
)

View File

@@ -57,12 +57,6 @@ class StoreError(ValueError):
pass
class AgentNotFoundError(NotFoundError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(NotFoundError):
"""Raised when a creator is not found"""

View File

@@ -568,7 +568,7 @@ async def hybrid_search(
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
ON uce."contentId" = sa.listing_version_id
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
@@ -582,7 +582,7 @@ async def hybrid_search(
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
ON uce."contentId" = sa.listing_version_id
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
@@ -605,7 +605,7 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
sa.graph_id,
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -627,9 +627,9 @@ async def hybrid_search(
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa."storeListingVersionId"
ON c."storeListingVersionId" = sa.listing_version_id
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId"
ON sa.listing_version_id = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
@@ -665,7 +665,7 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
graph_id,
searchable_text,
semantic_score,
lexical_score,

View File

@@ -1,11 +1,14 @@
import datetime
from typing import List
from typing import TYPE_CHECKING, List, Self
import prisma.enums
import pydantic
from backend.util.models import Pagination
if TYPE_CHECKING:
import prisma.models
class ChangelogEntry(pydantic.BaseModel):
version: str
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
date: datetime.datetime
class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
class MyUnpublishedAgent(pydantic.BaseModel):
graph_id: str
graph_version: int
agent_name: str
agent_image: str | None = None
description: str
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class MyAgentsResponse(pydantic.BaseModel):
agents: list[MyAgent]
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
agents: list[MyUnpublishedAgent]
pagination: Pagination
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
rating: float
agent_graph_id: str
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
return cls(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.graph_id,
)
class StoreAgentsResponse(pydantic.BaseModel):
agents: list[StoreAgent]
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
runs: int
rating: float
versions: list[str]
agentGraphVersions: list[str]
agentGraphId: str
graph_id: str
graph_versions: list[str]
last_updated: datetime.datetime
recommended_schedule_cron: str | None = None
active_version_id: str | None = None
has_approved_version: bool = False
active_version_id: str
has_approved_version: bool
# Optional changelog data when include_changelog=True
changelog: list[ChangelogEntry] | None = None
class Creator(pydantic.BaseModel):
name: str
username: str
description: str
avatar_url: str
num_agents: int
agent_rating: float
agent_runs: int
is_featured: bool
class CreatorsResponse(pydantic.BaseModel):
creators: List[Creator]
pagination: Pagination
class CreatorDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str]
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
return cls(
store_listing_version_id=agent.listing_version_id,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username or "",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
graph_id=agent.graph_id,
graph_versions=agent.graph_versions,
last_updated=agent.updated_at,
recommended_schedule_cron=agent.recommended_schedule_cron,
active_version_id=agent.listing_version_id,
has_approved_version=True, # StoreAgent view only has approved agents
)
class Profile(pydantic.BaseModel):
name: str
"""Marketplace user profile (only attributes that the user can update)"""
username: str
name: str
description: str
avatar_url: str | None
links: list[str]
avatar_url: str
is_featured: bool = False
class ProfileDetails(Profile):
"""Marketplace user profile (including read-only fields)"""
is_featured: bool
@classmethod
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
return cls(
name=profile.name,
username=profile.username,
avatar_url=profile.avatarUrl,
description=profile.description,
links=profile.links,
is_featured=profile.isFeatured,
)
class CreatorDetails(ProfileDetails):
"""Marketplace creator profile details, including aggregated stats"""
num_agents: int
agent_runs: int
agent_rating: float
top_categories: list[str]
@classmethod
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
return cls(
name=creator.name,
username=creator.username,
avatar_url=creator.avatar_url,
description=creator.description,
links=creator.links,
is_featured=creator.is_featured,
num_agents=creator.num_agents,
agent_runs=creator.agent_runs,
agent_rating=creator.agent_rating,
top_categories=creator.top_categories,
)
class CreatorsResponse(pydantic.BaseModel):
creators: List[CreatorDetails]
pagination: Pagination
class StoreSubmission(pydantic.BaseModel):
# From StoreListing:
listing_id: str
agent_id: str
agent_version: int
user_id: str
slug: str
# From StoreListingVersion:
listing_version_id: str
listing_version: int
graph_id: str
graph_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: str | None = None
instructions: str | None
categories: list[str]
image_urls: list[str]
date_submitted: datetime.datetime
status: prisma.enums.SubmissionStatus
runs: int
rating: float
store_listing_version_id: str | None = None
version: int | None = None # Actual version number from the database
video_url: str | None
agent_output_demo_url: str | None
submitted_at: datetime.datetime | None
changes_summary: str | None
status: prisma.enums.SubmissionStatus
reviewed_at: datetime.datetime | None = None
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
# Additional fields for editing
video_url: str | None = None
agent_output_demo_url: str | None = None
categories: list[str] = []
# Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count: int = 0
review_count: int = 0
review_avg_rating: float = 0.0
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
"""Construct from the StoreSubmission Prisma view."""
return cls(
listing_id=_sub.listing_id,
user_id=_sub.user_id,
slug=_sub.slug,
listing_version_id=_sub.listing_version_id,
listing_version=_sub.listing_version,
graph_id=_sub.graph_id,
graph_version=_sub.graph_version,
name=_sub.name,
sub_heading=_sub.sub_heading,
description=_sub.description,
instructions=_sub.instructions,
categories=_sub.categories,
image_urls=_sub.image_urls,
video_url=_sub.video_url,
agent_output_demo_url=_sub.agent_output_demo_url,
submitted_at=_sub.submitted_at,
changes_summary=_sub.changes_summary,
status=_sub.status,
reviewed_at=_sub.reviewed_at,
reviewer_id=_sub.reviewer_id,
review_comments=_sub.review_comments,
run_count=_sub.run_count,
review_count=_sub.review_count,
review_avg_rating=_sub.review_avg_rating,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
"""
Construct from the StoreListingVersion Prisma model (with StoreListing included)
"""
if not (_l := _lv.StoreListing):
raise ValueError("StoreListingVersion must have included StoreListing")
return cls(
listing_id=_l.id,
user_id=_l.owningUserId,
slug=_l.slug,
listing_version_id=_lv.id,
listing_version=_lv.version,
graph_id=_lv.agentGraphId,
graph_version=_lv.agentGraphVersion,
name=_lv.name,
sub_heading=_lv.subHeading,
description=_lv.description,
instructions=_lv.instructions,
categories=_lv.categories,
image_urls=_lv.imageUrls,
video_url=_lv.videoUrl,
agent_output_demo_url=_lv.agentOutputDemoUrl,
submitted_at=_lv.submittedAt,
changes_summary=_lv.changesSummary,
status=_lv.submissionStatus,
reviewed_at=_lv.reviewedAt,
reviewer_id=_lv.reviewerId,
review_comments=_lv.reviewComments,
)
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
pagination: Pagination
class StoreListingWithVersions(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
slug: str
agent_id: str
agent_version: int
active_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmission | None = None
versions: list[StoreSubmission] = []
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersions]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
graph_id: str = pydantic.Field(
..., min_length=1, description="Graph ID cannot be empty"
)
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
graph_version: int = pydantic.Field(
..., gt=0, description="Graph version must be greater than 0"
)
slug: str
name: str
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class ProfileDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str | None = None
class StoreSubmissionAdminView(StoreSubmission):
internal_comments: str | None # Private admin notes
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
return cls(
**StoreSubmission.from_db(_sub).model_dump(),
internal_comments=_sub.internal_comments,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
return cls(
**StoreSubmission.from_listing_version(_lv).model_dump(),
internal_comments=_lv.internalComments,
)
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
graph_id: str
slug: str
active_listing_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmissionAdminView | None = None
versions: list[StoreSubmissionAdminView] = []
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersionsAdminView]
pagination: Pagination
class StoreReview(pydantic.BaseModel):

View File

@@ -1,203 +0,0 @@
import datetime
import prisma.enums
from . import model as store_model
def test_pagination():
pagination = store_model.Pagination(
total_items=100, total_pages=5, current_page=2, page_size=20
)
assert pagination.total_items == 100
assert pagination.total_pages == 5
assert pagination.current_page == 2
assert pagination.page_size == 20
def test_store_agent():
agent = store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
response = store_model.StoreAgentsResponse(
agents=[
store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.agents) == 1
assert response.pagination.total_items == 1
def test_store_agent_details():
details = store_model.StoreAgentDetails(
store_listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_output_demo="demo.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["cat1", "cat2"],
runs=50,
rating=4.5,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=datetime.datetime.now(),
)
assert details.slug == "test-agent"
assert len(details.agent_image) == 2
assert len(details.categories) == 2
assert len(details.versions) == 2
def test_creator():
creator = store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
assert creator.name == "Test Creator"
assert creator.num_agents == 5
def test_creators_response():
response = store_model.CreatorsResponse(
creators=[
store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.creators) == 1
assert response.pagination.total_items == 1
def test_creator_details():
details = store_model.CreatorDetails(
name="Test Creator",
username="creator1",
description="Test description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["cat1", "cat2"],
)
assert details.name == "Test Creator"
assert len(details.links) == 2
assert details.agent_rating == 4.8
assert len(details.top_categories) == 2
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg", "image2.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
assert submission.name == "Test Agent"
assert len(submission.image_urls) == 2
assert submission.status == prisma.enums.SubmissionStatus.PENDING
def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.submissions) == 1
assert response.pagination.total_items == 1
def test_store_submission_request():
request = store_model.StoreSubmissionRequest(
agent_id="agent123",
agent_version=1,
slug="test-agent",
name="Test Agent",
sub_heading="Test subheading",
video_url="video.mp4",
image_urls=["image1.jpg", "image2.jpg"],
description="Test description",
categories=["cat1", "cat2"],
)
assert request.agent_id == "agent123"
assert request.agent_version == 1
assert len(request.image_urls) == 2
assert len(request.categories) == 2

View File

@@ -1,16 +1,17 @@
import logging
import tempfile
import typing
import urllib.parse
from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
from fastapi import Query, Security
from pydantic import BaseModel
import backend.data.graph
import backend.util.json
from backend.util.exceptions import NotFoundError
from backend.util.models import Pagination
from . import cache as store_cache
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
"/profile",
summary="Get user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.ProfileDetails,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_profile(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Get the profile details for the authenticated user.
Cached for 1 hour per user.
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Get the profile details for the authenticated user."""
profile = await store_db.get_user_profile(user_id)
if profile is None:
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Profile not found"},
)
raise NotFoundError("User does not have a profile yet")
return profile
@@ -57,98 +51,17 @@ async def get_profile(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.CreatorDetails,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def update_or_create_profile(
profile: store_model.Profile,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Update the store profile for the authenticated user.
Args:
profile (Profile): The updated profile details
user_id (str): ID of the authenticated user
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If there is an error updating the profile
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Update the store profile for the authenticated user."""
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
return updated_profile
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
response_model=store_model.StoreAgentsResponse,
)
async def get_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
creator (str | None, optional): Filter agents by creator username. Defaults to None.
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
category (str | None, optional): Filter agents by category. Defaults to None.
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of agents per page. Defaults to 20.
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
Raises:
HTTPException: If page or page_size are less than 1
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
##############################################
############### Search Endpoints #############
##############################################
@@ -158,60 +71,30 @@ async def get_agents(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[str] | None = fastapi.Query(
content_types: list[prisma.enums.ContentType] | None = Query(
default=None,
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
description="Content types to search. If not specified, searches all.",
),
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
user_id: str | None = Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
):
) -> store_model.UnifiedSearchResponse:
"""
Search across all content types (store agents, blocks, documentation) using hybrid search.
Search across all content types (marketplace agents, blocks, documentation)
using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_type_enums,
content_types=content_types,
user_id=user_id,
page=page,
page_size=page_size,
@@ -245,22 +128,69 @@ async def unified_search(
)
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
)
async def get_agents(
featured: bool = Query(
default=False, description="Filter to only show featured agents"
),
creator: str | None = Query(
default=None, description="Filter agents by creator username"
),
category: str | None = Query(default=None, description="Filter agents by category"),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
default=None,
description="Property to sort results by. Ignored if search_query is provided.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the marketplace,
with optional filtering and sorting.
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",
tags=["store", "public"],
response_model=store_model.StoreAgentDetails,
)
async def get_agent(
async def get_agent_by_name(
username: str,
agent_name: str,
include_changelog: bool = fastapi.Query(default=False),
):
"""
This is only used on the AgentDetails Page.
It returns the store listing agents details.
"""
include_changelog: bool = Query(default=False),
) -> store_model.StoreAgentDetails:
"""Get details of a marketplace agent"""
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
@@ -270,76 +200,82 @@ async def get_agent(
return agent
@router.get(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""
Get Agent Graph from Store Listing Version ID.
"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
"""
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.post(
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreReview,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def create_review(
async def post_user_review_for_agent(
username: str,
agent_name: str,
review: store_model.StoreReviewCreate,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a review for a store agent.
Args:
username: Creator's username
agent_name: Name/slug of the agent
review: Review details including score and optional comments
user_id: ID of authenticated user creating the review
Returns:
The created review
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreReview:
"""Post a user review on a marketplace agent listing"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await store_db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
@router.get(
"/listings/versions/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_agent_by_listing_version(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.get(
"/listings/versions/{store_listing_version_id}/graph",
summary="Get agent graph",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""Get outline of graph belonging to a specific marketplace listing version"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/listings/versions/{store_listing_version_id}/graph/download",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str,
) -> fastapi.responses.FileResponse:
"""Download agent graph file for a specific marketplace listing version"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
##############################################
############# Creator Endpoints #############
##############################################
@@ -349,37 +285,19 @@ async def create_review(
"/creators",
summary="List store creators",
tags=["store", "public"],
response_model=store_model.CreatorsResponse,
)
async def get_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
):
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
---
To support this functionality we need:
- featured: bool - to limit the list to just featured agents
- search_query: str - vector search based on the creators profile description.
- sorted_by: [agent_rating, agent_runs] -
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
featured: bool = Query(
default=False, description="Filter to only show featured creators"
),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.CreatorsResponse:
"""List or search marketplace creators"""
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
@@ -391,18 +309,12 @@ async def get_creators(
@router.get(
"/creator/{username}",
"/creators/{username}",
summary="Get creator details",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)
async def get_creator(
username: str,
):
"""
Get the details of a creator.
- Creator Details Page
"""
async def get_creator(username: str) -> store_model.CreatorDetails:
"""Get details on a marketplace creator"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
@@ -414,20 +326,17 @@ async def get_creator(
@router.get(
"/myagents",
"/my-unpublished-agents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.MyAgentsResponse,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_my_agents(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
"""
Get user's own agents.
"""
async def get_my_unpublished_agents(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.MyUnpublishedAgentsResponse:
"""List the authenticated user's unpublished agents"""
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
return agents
@@ -436,28 +345,17 @@ async def get_my_agents(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=bool,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def delete_submission(
submission_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Delete a store listing submission.
Args:
user_id (str): ID of the authenticated user
submission_id (str): ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> bool:
"""Delete a marketplace listing submission"""
result = await store_db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
@@ -465,37 +363,14 @@ async def delete_submission(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmissionsResponse,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_submissions(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of store submissions for the authenticated user.
Args:
user_id (str): ID of the authenticated user
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of submissions per page. Defaults to 20.
Returns:
StoreListingsResponse: Paginated list of store submissions
Raises:
HTTPException: If page or page_size are less than 1
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreSubmissionsResponse:
"""List the authenticated user's marketplace listing submissions"""
listings = await store_db.get_store_submissions(
user_id=user_id,
page=page,
@@ -508,30 +383,17 @@ async def get_submissions(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def create_submission(
submission_request: store_model.StoreSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new store listing submission.
Args:
submission_request (StoreSubmissionRequest): The submission details
user_id (str): ID of the authenticated user submitting the listing
Returns:
StoreSubmission: The created store submission
Raises:
HTTPException: If there is an error creating the submission
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Submit a new marketplace listing for review"""
result = await store_db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
graph_id=submission_request.graph_id,
graph_version=submission_request.graph_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
@@ -544,7 +406,6 @@ async def create_submission(
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -552,28 +413,14 @@ async def create_submission(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def edit_submission(
store_listing_version_id: str,
submission_request: store_model.StoreSubmissionEditRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Edit an existing store listing submission.
Args:
store_listing_version_id (str): ID of the store listing version to edit
submission_request (StoreSubmissionRequest): The updated submission details
user_id (str): ID of the authenticated user editing the listing
Returns:
StoreSubmission: The updated store submission
Raises:
HTTPException: If there is an error editing the submission
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Update a pending marketplace listing submission"""
result = await store_db.edit_store_submission(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
@@ -588,7 +435,6 @@ async def edit_submission(
changes_summary=submission_request.changes_summary,
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -596,115 +442,61 @@ async def edit_submission(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Upload media (images/videos) for a store listing submission.
Args:
file (UploadFile): The media file to upload
user_id (str): ID of the authenticated user uploading the media
Returns:
str: URL of the uploaded media file
Raises:
HTTPException: If there is an error uploading the media
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> str:
"""Upload media for a marketplace listing submission"""
media_url = await store_media.upload_media(user_id=user_id, file=file)
return media_url
class ImageURLResponse(BaseModel):
image_url: str
@router.post(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def generate_image(
agent_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> fastapi.responses.Response:
graph_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> ImageURLResponse:
"""
Generate an image for a store listing submission.
Args:
agent_id (str): ID of the agent to generate an image for
user_id (str): ID of the authenticated user
Returns:
JSONResponse: JSON containing the URL of the generated image
Generate an image for a marketplace listing submission based on the properties
of a given graph.
"""
agent = await backend.data.graph.get_graph(
graph_id=agent_id, version=None, user_id=user_id
graph = await backend.data.graph.get_graph(
graph_id=graph_id, version=None, user_id=user_id
)
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
if not graph:
raise NotFoundError(f"Agent graph #{graph_id} not found")
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{agent_id}.jpeg"
filename = f"agent_{graph_id}.jpeg"
existing_url = await store_media.check_media_exists(user_id, filename)
if existing_url:
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
logger.info(f"Using existing image for agent graph {graph_id}")
return ImageURLResponse(image_url=existing_url)
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(agent=agent)
image = await store_image_gen.generate_agent_image(agent=graph)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
@router.get(
"/download/agents/{store_listing_version_id}",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
return ImageURLResponse(image_url=image_url)
##############################################

View File

@@ -8,6 +8,8 @@ import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from backend.api.features.store.db import StoreAgentsSortOptions
from . import model as store_model
from . import routes as store_routes
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
mock_db_call.assert_called_once_with(
featured=False,
creators=None,
sorted_by="runs",
sorted_by=StoreAgentsSortOptions.RUNS,
search_query=None,
category=None,
page=1,
@@ -380,9 +382,11 @@ def test_get_agent_details(
runs=100,
rating=4.5,
versions=["1.0.0", "1.1.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
graph_versions=["1", "2"],
graph_id="test-graph-id",
last_updated=FIXED_NOW,
active_version_id="test-version-id",
has_approved_version=True,
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
mock_db_call.return_value = mocked_value
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
) -> None:
mocked_value = store_model.CreatorsResponse(
creators=[
store_model.Creator(
store_model.CreatorDetails(
name=f"Creator {i}",
username=f"creator{i}",
description=f"Creator {i} description",
avatar_url=f"avatar{i}.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=100,
description=f"Creator {i} description",
links=[f"user{i}.link.com"],
is_featured=False,
num_agents=1,
agent_runs=100,
agent_rating=4.5,
top_categories=["cat1", "cat2", "cat3"],
)
for i in range(5)
],
@@ -496,19 +502,19 @@ def test_get_creator_details(
mocked_value = store_model.CreatorDetails(
name="Test User",
username="creator1",
avatar_url="avatar.jpg",
description="Test creator description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
is_featured=True,
num_agents=5,
agent_runs=1000,
agent_rating=4.8,
top_categories=["category1", "category2"],
)
mock_db_call = mocker.patch(
"backend.api.features.store.db.get_store_creator_details"
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
mock_db_call.return_value = mocked_value
response = client.get("/creator/creator1")
response = client.get("/creators/creator1")
assert response.status_code == 200
data = store_model.CreatorDetails.model_validate(response.json())
@@ -528,19 +534,26 @@ def test_get_submissions_success(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],
date_submitted=FIXED_NOW,
status=prisma.enums.SubmissionStatus.APPROVED,
runs=50,
rating=4.2,
agent_id="test-agent-id",
agent_version=1,
sub_heading="Test agent subheading",
user_id="test-user-id",
slug="test-agent",
video_url="test.mp4",
listing_version_id="test-version-id",
listing_version=1,
graph_id="test-agent-id",
graph_version=1,
name="Test Agent",
sub_heading="Test agent subheading",
description="Test agent description",
instructions="Click the button!",
categories=["test-category"],
image_urls=["test.jpg"],
video_url="test.mp4",
agent_output_demo_url="demo_video.mp4",
submitted_at=FIXED_NOW,
changes_summary="Initial Submission",
status=prisma.enums.SubmissionStatus.APPROVED,
run_count=50,
review_count=5,
review_avg_rating=4.2,
)
],
pagination=store_model.Pagination(

View File

@@ -11,6 +11,7 @@ import pytest
from backend.util.models import Pagination
from . import cache as store_cache
from .db import StoreAgentsSortOptions
from .model import StoreAgent, StoreAgentsResponse
@@ -215,7 +216,7 @@ class TestCacheDeletion:
await store_cache._get_cached_store_agents(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,
@@ -227,7 +228,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,
@@ -239,7 +240,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,

View File

@@ -55,6 +55,7 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import get_or_activate_user
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
user = await get_or_activate_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -165,8 +163,11 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> dict[str, str]:
await get_or_activate_user(user_data)
await update_user_email(user_id, email)
return {"email": email}
@@ -182,7 +183,7 @@ async def get_user_timezone_route(
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
user = await get_or_activate_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -193,9 +194,12 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
await get_or_activate_user(user_data)
user = await update_user_timezone(user_id, str(request.timezone))
return TimezoneResponse(timezone=user.timezone)
@@ -208,7 +212,9 @@ async def update_user_timezone_route(
)
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
preferences = await get_user_notification_preference(user_id)
return preferences
@@ -222,7 +228,9 @@ async def get_preferences(
async def update_preferences(
user_id: Annotated[str, Security(get_user_id)],
preferences: NotificationPreferenceDTO = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
output = await update_user_notification_preference(user_id, preferences)
return output
@@ -449,7 +457,6 @@ async def execute_graph_block(
async def upload_file(
user_id: Annotated[str, Security(get_user_id)],
file: UploadFile = File(...),
provider: str = "gcs",
expiration_hours: int = 24,
) -> UploadFileResponse:
"""
@@ -512,7 +519,6 @@ async def upload_file(
storage_path = await cloud_storage.store_file(
content=content,
filename=file_name,
provider=provider,
expiration_hours=expiration_hours,
user_id=user_id,
)

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_create_user",
"backend.api.features.v1.get_or_activate_user",
return_value=mock_user,
)
@@ -515,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
result = await upload_file(
file=upload_file_mock,
user_id=test_user_id,
provider="gcs",
expiration_hours=24,
)
@@ -533,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
mock_handler.store_file.assert_called_once_with(
content=file_content,
filename="test.txt",
provider="gcs",
expiration_hours=24,
user_id=test_user_id,
)

View File

@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -55,6 +56,7 @@ from backend.util.exceptions import (
MissingConfigError,
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
@@ -275,6 +277,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
@@ -309,6 +312,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -116,6 +116,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -274,6 +275,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101

View File

@@ -83,7 +83,8 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
@property
def provider_name(self) -> str:
@@ -137,7 +138,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -227,7 +228,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -324,7 +325,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()

View File

@@ -1,8 +1,8 @@
import logging
from typing import Literal
from pydantic import BaseModel
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks._base import (
Block,
BlockCategory,
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
category: str | None = SchemaField(
description="Filter by category", default=None
)
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
description="How to sort the results", default="rating"
sort_by: StoreAgentsSortOptions = SchemaField(
description="How to sort the results", default=StoreAgentsSortOptions.RATING
)
limit: int = SchemaField(
description="Maximum number of results to return", default=10, ge=1, le=100
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
self,
query: str | None = None,
category: str | None = None,
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
limit: int = 10,
) -> SearchAgentsResponse:
"""

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
import pytest
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks.system.library_operations import (
AddToLibraryFromStoreBlock,
LibraryAgent,
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
)
input_data = block.Input(
query="test", category="productivity", sort_by="rating", limit=10
query="test",
category="productivity",
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
limit=10,
)
outputs = {}

View File

@@ -22,6 +22,7 @@ from backend.copilot.model import (
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -62,8 +63,8 @@ async def _update_title_async(
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title:
await update_session_title(session_id, title)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
@@ -176,14 +177,17 @@ async def stream_chat_completion_baseline(
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
system_prompt, _ = await _build_system_prompt(
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
system_prompt, _ = await _build_system_prompt(
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)

View File

@@ -81,6 +81,35 @@ async def update_chat_session(
return ChatSession.from_db(session) if session else None
async def update_chat_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
Always filters by (session_id, user_id) so callers cannot mutate another
user's session even when they know the session_id.
Args:
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
guard so auto-generated titles never overwrite a user-set title.
Returns True if a row was updated, False otherwise (session not found,
wrong user, or — when only_if_empty — title was already set).
"""
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
if only_if_empty:
where["title"] = None
result = await PrismaChatSession.prisma().update_many(
where=where,
data={"title": title, "updatedAt": datetime.now(UTC)},
)
return result > 0
async def add_chat_message(
session_id: str,
role: str,

View File

@@ -469,8 +469,16 @@ async def upsert_chat_session(
)
db_error = e
# Save to cache (best-effort, even if DB failed)
# Save to cache (best-effort, even if DB failed).
# Title updates (update_session_title) run *outside* this lock because
# they only touch the title field, not messages. So a concurrent rename
# or auto-title may have written a newer title to Redis while this
# upsert was in progress. Always prefer the cached title to avoid
# overwriting it with the stale in-memory copy.
try:
existing_cached = await _get_session_from_cache(session.session_id)
if existing_cached and existing_cached.title:
session = session.model_copy(update={"title": existing_cached.title})
await cache_chat_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
@@ -685,30 +693,48 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return True
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
async def update_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Lightweight operation that doesn't touch messages, avoiding race conditions
with concurrent message updates.
Args:
session_id: The session ID to update.
user_id: Owning user — the DB query filters on this.
title: The new title to set.
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
so auto-generated titles never overwrite a user-set title.
Returns:
True if updated successfully, False otherwise.
True if updated successfully, False otherwise (not found, wrong user,
or — when only_if_empty — title was already set).
"""
try:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
updated = await chat_db().update_chat_session_title(
session_id, user_id, title, only_if_empty=only_if_empty
)
if not updated:
return False
# Invalidate the cache so the next access reloads from DB with the
# updated title. This avoids a read-modify-write on the full session
# blob, which could overwrite concurrent message updates.
await invalidate_session_cache(session_id)
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
except Exception as e:
logger.warning(
f"Cache title update failed for session {session_id} (non-critical): {e}"
)
return True
except Exception as e:

View File

@@ -1,29 +0,0 @@
"""Prompt constants for CoPilot - workflow guidance and supplementary documentation.
This module contains workflow patterns and guidance that supplement the main system prompt.
These are appended dynamically to the prompt along with auto-generated tool documentation.
"""
# Workflow guidance for key tool patterns
# This is appended after the auto-generated tool list to provide usage patterns
KEY_WORKFLOWS = """
## KEY WORKFLOWS
### MCP Integration Workflow
When using `run_mcp_tool`:
1. **Known servers** (use directly): Notion (https://mcp.notion.com/mcp), Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), Atlassian (https://mcp.atlassian.com/mcp)
2. **Unknown servers**: Use `web_search("{{service}} MCP server URL")` to find the endpoint
3. **Discovery**: Call `run_mcp_tool(server_url)` to see available tools
4. **Execution**: Call `run_mcp_tool(server_url, tool_name, tool_arguments)`
5. **Authentication**: If credentials needed, user will be prompted. When they confirm, retry immediately with same arguments.
### Agent Creation Workflow
When using `create_agent`:
1. Always check `find_library_agent` first for existing solutions
2. Call `create_agent` with description
3. **If `suggested_goal` returned**: Present to user, ask for confirmation, call again with suggested goal if accepted
4. **If `clarifying_questions` returned**: After user answers, call again with original description AND answers in `context` parameter
### Folder Management
Use folder tools (`create_folder`, `list_folders`, `move_agents_to_folder`) to organize agents in the user's library for better discoverability."""

View File

@@ -0,0 +1,191 @@
"""Centralized prompt building logic for CoPilot.
This module contains all prompt construction functions and constants,
handling the distinction between:
- SDK mode vs Baseline mode (tool documentation needs)
- Local mode vs E2B mode (storage/filesystem differences)
"""
from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
working_dir: str,
sandbox_type: str,
storage_system_1_name: str,
storage_system_1_characteristics: list[str],
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
) -> str:
"""Build storage/filesystem supplement for a specific environment.
Template function handles all formatting (bullets, indentation, markdown).
Callers provide clean data as lists of strings.
Args:
working_dir: Working directory path
sandbox_type: Description of bash_exec sandbox
storage_system_1_name: Name of primary storage (ephemeral or cloud)
storage_system_1_characteristics: List of characteristic descriptions
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
return f"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs {sandbox_type}.
### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations
### Two storage systems — CRITICAL to understand
1. **{storage_system_1_name}** (`{working_dir}`):
{characteristics}
{persistence}
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
storage_system_1_name="Ephemeral working directory",
storage_system_1_characteristics=[
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
],
storage_system_1_persistence=[
"Files here are **lost between turns** — do NOT rely on them persisting",
"Use for temporary work: running scripts, processing data, etc.",
],
file_move_name_1_to_2="Ephemeral → Persistent",
file_move_name_2_to_1="Persistent → Ephemeral",
)
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
storage_system_1_name="Cloud sandbox",
storage_system_1_characteristics=[
"Shared by all file tools AND `bash_exec` — same filesystem",
"Full Linux environment with internet access",
],
storage_system_1_persistence=[
"Files **persist across turns** within the current session",
"Lost when the session expires (12 h inactivity)",
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
)
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
SDK mode doesn't need it since Claude gets tool schemas automatically.
This generates a complete list of available tools with their descriptions,
ensuring the documentation stays in sync with the actual tool implementations.
All workflow guidance is now embedded in individual tool descriptions.
Only documents tools that are available in the current environment
(checked via tool.is_available property).
"""
docs = "\n## AVAILABLE TOOLS\n\n"
# Sort tools alphabetically for consistent output
# Filter by is_available to match get_available_tools() behavior
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
if not tool.is_available:
continue
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
docs += f"- **`{name}`**: {desc}\n"
return docs
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
"""Get the supplement for SDK mode (Claude Agent SDK).
SDK mode does NOT include tool documentation because Claude automatically
receives tool schemas from the SDK. Only includes technical notes about
storage systems and execution environment.
Args:
use_e2b: Whether E2B cloud sandbox is being used
cwd: Current working directory (only used in local_storage mode)
Returns:
The supplement string to append to the system prompt
"""
if use_e2b:
return _get_cloud_sandbox_supplement()
return _get_local_storage_supplement(cwd)
def get_baseline_supplement() -> str:
"""Get the supplement for baseline mode (direct OpenAI API).
Baseline mode INCLUDES auto-generated tool documentation because the
direct API doesn't automatically provide tool schemas to Claude.
Also includes shared technical notes (but NOT SDK-specific environment details).
Returns:
The supplement string to append to the system prompt
"""
tool_docs = _generate_tool_documentation()
return tool_docs + _SHARED_TOOL_NOTES

View File

@@ -44,7 +44,7 @@ from ..model import (
update_session_title,
upsert_chat_session,
)
from ..prompt_constants import KEY_WORKFLOWS
from ..prompting import get_sdk_supplement
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -60,7 +60,6 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..tools import TOOL_REGISTRY
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
@@ -148,169 +147,6 @@ _SDK_CWD_PREFIX = WORKSPACE_PREFIX
_HEARTBEAT_INTERVAL = 10.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
This generates a complete list of available tools with their descriptions,
ensuring the documentation stays in sync with the actual tool implementations.
"""
docs = "\n## AVAILABLE TOOLS\n\n"
# Sort tools alphabetically for consistent output
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
docs += f"- **`{name}`**: {desc}\n"
# Add workflow guidance for key tools
docs += KEY_WORKFLOWS
return docs
_SHARED_TOOL_NOTES = """\
### Web search and research
- **`web_search(query)`** — Search the web for current information (uses Claude's
native web search). Use this when you need up-to-date information, facts,
statistics, or current events that are beyond your knowledge cutoff.
- **`web_fetch(url)`** — Retrieve and analyze content from a specific URL.
Use this when you have a specific URL to read (documentation, articles, etc.).
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Long-running tools
Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
### Large tool outputs
When a tool output exceeds the display limit, it is automatically saved to
the persistent workspace. The truncated output includes a
`<tool-output-truncated>` tag with the workspace path. Use
`read_workspace_file(path="...", offset=N, length=50000)` to retrieve
additional sections.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
_LOCAL_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
### Working directory
- Your working directory is: `{cwd}`
- All SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec` operate inside this
directory. This is the ONLY writable path — do not attempt to read or write
anywhere else on the filesystem.
- Use relative paths or absolute paths under `{cwd}` for all file operations.
### Two storage systems — CRITICAL to understand
1. **Ephemeral working directory** (`{cwd}`):
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
- Files here are **lost between turns** — do NOT rely on them persisting
- Use for temporary work: running scripts, processing data, etc.
2. **Persistent workspace** (cloud storage):
- Files here **survive across turns and sessions**
- Use `write_workspace_file` to save important files (code, outputs, configs)
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between ephemeral and persistent storage
- **Ephemeral → Persistent**: Use `write_workspace_file` with either:
- `content` param (plain text) — for text files
- `source_path` param — to copy any file directly from the ephemeral dir
- **Persistent → Ephemeral**: Use `read_workspace_file` with `save_to_path`
param to download a workspace file to the ephemeral dir for processing
### File persistence workflow
When you create or modify important files (code, configs, outputs), you MUST:
1. Save them using `write_workspace_file` so they persist
2. At the start of a new turn, call `list_workspace_files` to see what files
are available from previous turns
"""
+ _SHARED_TOOL_NOTES
)
_E2B_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a cloud sandbox with full internet access.
### Working directory
- Your working directory is: `/home/user` (cloud sandbox)
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
- Files created by `bash_exec` are immediately visible to `read_file` and
vice-versa — they share one filesystem.
- Use relative paths (resolved from `/home/user`) or absolute paths.
### Two storage systems — CRITICAL to understand
1. **Cloud sandbox** (`/home/user`):
- Shared by all file tools AND `bash_exec` — same filesystem
- Files **persist across turns** within the current session
- Full Linux environment with internet access
- Lost when the session expires (12 h inactivity)
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
- Use `write_workspace_file` to save important files permanently
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between sandbox and persistent storage
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
to copy from the sandbox to permanent storage
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
to download into the sandbox for processing
### File persistence workflow
Important files that must survive beyond this session should be saved with
`write_workspace_file`. Sandbox files persist across turns but are lost
when the session expires.
"""
+ _SHARED_TOOL_NOTES
)
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
@@ -491,13 +327,14 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
}
)
elif isinstance(block, ToolResultBlock):
result.append(
{
"type": "tool_result",
"tool_use_id": block.tool_use_id,
"content": block.content,
}
)
tool_result_entry: dict[str, Any] = {
"type": "tool_result",
"tool_use_id": block.tool_use_id,
"content": block.content,
}
if block.is_error:
tool_result_entry["is_error"] = True
result.append(tool_result_entry)
elif isinstance(block, ThinkingBlock):
result.append(
{
@@ -996,16 +833,9 @@ async def stream_chat_completion_sdk(
)
use_e2b = e2b_sandbox is not None
# Generate tool documentation and append appropriate supplement
tool_docs = _generate_tool_documentation()
system_prompt = (
base_system_prompt
+ tool_docs
+ (
_E2B_TOOL_SUPPLEMENT
if use_e2b
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
)
# Append appropriate supplement (Claude gets tool schemas automatically)
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
@@ -1170,19 +1000,18 @@ async def stream_chat_completion_sdk(
json.dumps(user_msg) + "\n"
)
# Capture user message in transcript (multimodal)
transcript_builder.add_user_message(content=content_blocks)
transcript_builder.append_user(content=content_blocks)
else:
await client.query(query_message, session_id=session_id)
# Capture actual user message in transcript (not the engineered query)
# query_message may include context wrappers, but transcript needs raw input
transcript_builder.add_user_message(content=current_message)
transcript_builder.append_user(content=current_message)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
ended_with_stream_error = False
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
@@ -1253,15 +1082,6 @@ async def stream_chat_completion_sdk(
len(adapter.resolved_tool_calls),
)
# Capture SDK messages in transcript
if isinstance(sdk_msg, AssistantMessage):
content_blocks = _format_sdk_content_blocks(sdk_msg.content)
model_name = getattr(sdk_msg, "model", "")
transcript_builder.add_assistant_message(
content_blocks=content_blocks,
model=model_name,
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
@@ -1392,33 +1212,37 @@ async def stream_chat_completion_sdk(
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
tool_result_content = (
content = (
response.output
if isinstance(response.output, str)
else str(response.output)
else json.dumps(response.output, ensure_ascii=False)
)
session.messages.append(
ChatMessage(
role="tool",
content=tool_result_content,
content=content,
tool_call_id=response.toolCallId,
)
)
# Capture tool result in transcript as user message with tool_result content
transcript_builder.add_user_message(
content=[
{
"type": "tool_result",
"tool_use_id": response.toolCallId,
"content": tool_result_content,
}
]
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
# Append assistant entry AFTER convert_message so that
# any stashed tool results from the previous turn are
# recorded first, preserving the required API order:
# assistant(tool_use) → tool_result → assistant(text).
if isinstance(sdk_msg, AssistantMessage):
transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
model=sdk_msg.model,
)
except asyncio.CancelledError:
# Task/generator was cancelled (e.g. client disconnect,
# server shutdown). Log and let the safety-net / finally
@@ -1467,6 +1291,15 @@ async def stream_chat_completion_sdk(
type(response).__name__,
getattr(response, "toolName", "N/A"),
)
if isinstance(response, StreamToolOutputAvailable):
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=(
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
),
)
yield response
# If the stream ended without a ResultMessage, the SDK
@@ -1645,7 +1478,7 @@ async def _update_title_async(
message, user_id=user_id, session_id=session_id
)
if title and user_id:
await update_session_title(session_id, title)
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from .service import _generate_tool_documentation, _prepare_file_attachments
from .service import _prepare_file_attachments
@dataclass
@@ -147,92 +147,101 @@ class TestPrepareFileAttachments:
assert len(result.image_blocks) == 1
class TestGenerateToolDocumentation:
"""Tests for auto-generated tool documentation from TOOL_REGISTRY."""
class TestPromptSupplement:
"""Tests for centralized prompt supplement generation."""
def test_generate_tool_documentation_structure(self):
"""Test that tool documentation has expected structure."""
docs = _generate_tool_documentation()
def test_sdk_supplement_excludes_tool_docs(self):
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
from backend.copilot.prompting import get_sdk_supplement
# Check main sections exist
assert "## AVAILABLE TOOLS" in docs
assert "## KEY WORKFLOWS" in docs
# Test both local and E2B modes
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
# Verify no duplicate sections
assert docs.count("## AVAILABLE TOOLS") == 1
assert docs.count("## KEY WORKFLOWS") == 1
# Should NOT have tool list section
assert "## AVAILABLE TOOLS" not in local_supplement
assert "## AVAILABLE TOOLS" not in e2b_supplement
def test_tool_documentation_includes_key_tools(self):
"""Test that documentation includes essential copilot tools."""
docs = _generate_tool_documentation()
# Should still have technical notes
assert "## Tool notes" in local_supplement
assert "## Tool notes" in e2b_supplement
# Core agent workflow tools
def test_baseline_supplement_includes_tool_docs(self):
"""Baseline mode MUST include tool documentation (direct API needs it)."""
from backend.copilot.prompting import get_baseline_supplement
supplement = get_baseline_supplement()
# MUST have tool list section
assert "## AVAILABLE TOOLS" in supplement
# Should NOT have environment-specific notes (SDK-only)
assert "## Tool notes" not in supplement
def test_baseline_supplement_includes_key_tools(self):
"""Baseline supplement should document all essential tools."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Core agent workflow tools (always available)
assert "`create_agent`" in docs
assert "`run_agent`" in docs
assert "`find_library_agent`" in docs
assert "`edit_agent`" in docs
# MCP integration
# MCP integration (always available)
assert "`run_mcp_tool`" in docs
# Browser automation
assert "`browser_navigate`" in docs
# Folder management
# Folder management (always available)
assert "`create_folder`" in docs
def test_tool_documentation_format(self):
"""Test that each tool follows bullet list format."""
docs = _generate_tool_documentation()
# Browser tools only if available (Playwright may not be installed in CI)
if (
TOOL_REGISTRY.get("browser_navigate")
and TOOL_REGISTRY["browser_navigate"].is_available
):
assert "`browser_navigate`" in docs
lines = docs.split("\n")
tool_lines = [line for line in lines if line.strip().startswith("- **`")]
def test_baseline_supplement_includes_workflows(self):
"""Baseline supplement should include workflow guidance in tool descriptions."""
from backend.copilot.prompting import get_baseline_supplement
# Should have multiple tools (at least 20 from TOOL_REGISTRY)
assert len(tool_lines) >= 20
docs = get_baseline_supplement()
# Each tool line should have proper markdown format
for line in tool_lines:
assert line.startswith("- **`"), f"Bad format: {line}"
assert "`**:" in line, f"Missing description separator: {line}"
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "suggested_goal" in docs or "clarifying_questions" in docs
assert "run_mcp_tool" in docs
def test_tool_documentation_includes_workflows(self):
"""Test that key workflow patterns are documented."""
docs = _generate_tool_documentation()
# Check workflow sections
assert "MCP Integration Workflow" in docs
assert "Agent Creation Workflow" in docs
assert "Folder Management" in docs
# Check workflow details
assert "suggested_goal" in docs # Agent creation feedback loop
assert "clarifying_questions" in docs # Agent creation feedback loop
assert "run_mcp_tool(server_url)" in docs # MCP discovery pattern
def test_tool_documentation_completeness(self):
"""Test that all tools from TOOL_REGISTRY appear in documentation."""
def test_baseline_supplement_completeness(self):
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = _generate_tool_documentation()
docs = get_baseline_supplement()
# Verify each registered tool is documented
for tool_name in TOOL_REGISTRY.keys():
# Verify each available registered tool is documented
# (matches _generate_tool_documentation which filters by is_available)
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
assert (
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from auto-generated documentation"
), f"Tool '{tool_name}' missing from baseline supplement"
def test_tool_documentation_no_duplicate_tools(self):
"""Test that no tool appears multiple times in the list."""
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = _generate_tool_documentation()
docs = get_baseline_supplement()
# Extract the tools section (before KEY WORKFLOWS)
tools_section = docs.split("## KEY WORKFLOWS")[0]
# Count occurrences of each tool
for tool_name in TOOL_REGISTRY.keys():
# Count occurrences of each available tool in the entire supplement
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
# Count how many times this tool appears as a bullet point
count = tools_section.count(f"- **`{tool_name}`**")
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"

View File

@@ -44,6 +44,15 @@ class TranscriptBuilder:
self._entries: list[TranscriptEntry] = []
self._last_uuid: str | None = None
def _last_is_assistant(self) -> bool:
return bool(self._entries) and self._entries[-1].type == "assistant"
def _last_message_id(self) -> str:
"""Return the message.id of the last entry, or '' if none."""
if self._entries:
return self._entries[-1].message.get("id", "")
return ""
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
@@ -91,10 +100,8 @@ class TranscriptBuilder:
self._last_uuid[:12] if self._last_uuid else None,
)
def add_user_message(
self, content: str | list[dict], uuid: str | None = None
) -> None:
"""Add user message to the complete context."""
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
"""Append a user entry."""
msg_uuid = uuid or str(uuid4())
self._entries.append(
@@ -107,10 +114,34 @@ class TranscriptBuilder:
)
self._last_uuid = msg_uuid
def add_assistant_message(
self, content_blocks: list[dict], model: str = ""
def append_tool_result(self, tool_use_id: str, content: str) -> None:
"""Append a tool result as a user entry (one per tool call)."""
self.append_user(
content=[
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
]
)
def append_assistant(
self,
content_blocks: list[dict],
model: str = "",
stop_reason: str | None = None,
) -> None:
"""Add assistant message to the complete context."""
"""Append an assistant entry.
Consecutive assistant entries automatically share the same message ID
so the CLI can merge them (thinking → text → tool_use) into a single
API message on ``--resume``. A new ID is assigned whenever an
assistant entry follows a non-assistant entry (user message or tool
result), because that marks the start of a new API response.
"""
message_id = (
self._last_message_id()
if self._last_is_assistant()
else f"msg_sdk_{uuid4().hex[:24]}"
)
msg_uuid = str(uuid4())
self._entries.append(
@@ -121,7 +152,11 @@ class TranscriptBuilder:
message={
"role": "assistant",
"model": model,
"id": message_id,
"type": "message",
"content": content_blocks,
"stop_reason": stop_reason,
"stop_sequence": None,
},
)
)
@@ -130,6 +165,9 @@ class TranscriptBuilder:
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
Consecutive assistant entries are kept separate to match the
native CLI format — the SDK merges them internally on resume.
Returns the FULL conversation state (all entries), not incremental.
This output REPLACES any previous transcript.
"""

View File

@@ -18,7 +18,7 @@ from langfuse.openai import (
from backend.data.db_accessors import understanding_db
from backend.data.understanding import format_understanding_for_prompt
from backend.util.exceptions import NotFoundError
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ChatConfig
@@ -198,6 +198,12 @@ async def assign_user_to_session(
session = await get_chat_session(session_id, None)
if not session:
raise NotFoundError(f"Session {session_id} not found")
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"[SECURITY] Attempt to claim session {session_id} by user {user_id}, "
f"but it already belongs to user {session.user_id}"
)
raise NotAuthorizedError(f"Not authorized to claim session {session_id}")
session.user_id = user_id
session = await upsert_chat_session(session)
return session

View File

@@ -20,6 +20,14 @@ from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .get_doc_page import GetDocPageTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
@@ -47,6 +55,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
"find_library_agent": FindLibraryAgentTool(),
# Folder management tools
"create_folder": CreateFolderTool(),
"list_folders": ListFoldersTool(),
"update_folder": UpdateFolderTool(),
"move_folder": MoveFolderTool(),
"delete_folder": DeleteFolderTool(),
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),

View File

@@ -151,8 +151,8 @@ async def setup_test_data(server):
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="Test Agent",
description="A simple test agent",
@@ -161,10 +161,10 @@ async def setup_test_data(server):
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
# 4. Approve the store listing version
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval",
@@ -321,8 +321,8 @@ async def setup_llm_test_data(server):
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="LLM Test Agent",
description="An agent with LLM capabilities",
@@ -330,9 +330,9 @@ async def setup_llm_test_data(server):
categories=["testing", "ai"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for LLM agent",
@@ -476,8 +476,8 @@ async def setup_firecrawl_test_data(server):
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="Firecrawl Test Agent",
description="An agent with Firecrawl integration (no credentials)",
@@ -485,9 +485,9 @@ async def setup_firecrawl_test_data(server):
categories=["testing", "scraping"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for Firecrawl agent",

View File

@@ -695,7 +695,10 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
agent_json: dict[str, Any],
user_id: str,
is_update: bool = False,
folder_id: str | None = None,
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
@@ -703,6 +706,7 @@ async def save_agent_to_library(
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
folder_id: Optional folder ID to place the agent in
Returns:
Tuple of (created Graph, LibraryAgent)
@@ -711,7 +715,7 @@ async def save_agent_to_library(
db = library_db()
if is_update:
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
def graph_to_json(graph: Graph) -> dict[str, Any]:

View File

@@ -39,9 +39,13 @@ class CreateAgentTool(BaseTool):
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
"(2) Call create_agent with description and library_agent_ids. "
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
"then call again with the suggested goal if accepted. "
"(4) If response contains clarifying_questions: Present to user, collect answers, "
"then call again with original description AND answers in the context parameter. "
"\n\nThis feedback loop ensures the generated agent matches user intent."
)
@property
@@ -84,6 +88,14 @@ class CreateAgentTool(BaseTool):
),
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
},
},
"required": ["description"],
}
@@ -105,6 +117,7 @@ class CreateAgentTool(BaseTool):
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
logger.info(
@@ -336,7 +349,7 @@ class CreateAgentTool(BaseTool):
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id
agent_json, user_id, folder_id=folder_id
)
logger.info(

View File

@@ -3,9 +3,9 @@
import logging
from typing import Any
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from backend.util.exceptions import NotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
@@ -80,6 +80,14 @@ class CustomizeAgentTool(BaseTool):
),
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
},
},
"required": ["agent_id", "modifications"],
}
@@ -102,6 +110,7 @@ class CustomizeAgentTool(BaseTool):
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
if not agent_id:
@@ -140,7 +149,7 @@ class CustomizeAgentTool(BaseTool):
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
except NotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
@@ -310,7 +319,7 @@ class CustomizeAgentTool(BaseTool):
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
customized_agent, user_id, is_update=False, folder_id=folder_id
)
return AgentSavedResponse(

View File

@@ -0,0 +1,573 @@
"""Folder management tools for the copilot."""
from typing import Any
from backend.api.features.library import model as library_model
from backend.api.features.library.db import collect_tree_ids
from backend.copilot.model import ChatSession
from backend.data.db_accessors import library_db
from .base import BaseTool
from .models import (
AgentsMovedToFolderResponse,
ErrorResponse,
FolderAgentSummary,
FolderCreatedResponse,
FolderDeletedResponse,
FolderInfo,
FolderListResponse,
FolderMovedResponse,
FolderTreeInfo,
FolderUpdatedResponse,
ToolResponseBase,
)
def _folder_to_info(
folder: library_model.LibraryFolder,
agents: list[FolderAgentSummary] | None = None,
) -> FolderInfo:
"""Convert a LibraryFolder DB model to a FolderInfo response model."""
return FolderInfo(
id=folder.id,
name=folder.name,
parent_id=folder.parent_id,
icon=folder.icon,
color=folder.color,
agent_count=folder.agent_count,
subfolder_count=folder.subfolder_count,
agents=agents,
)
def _tree_to_info(
tree: library_model.LibraryFolderTree,
agents_map: dict[str, list[FolderAgentSummary]] | None = None,
) -> FolderTreeInfo:
"""Recursively convert a LibraryFolderTree to a FolderTreeInfo response."""
return FolderTreeInfo(
id=tree.id,
name=tree.name,
parent_id=tree.parent_id,
icon=tree.icon,
color=tree.color,
agent_count=tree.agent_count,
subfolder_count=tree.subfolder_count,
children=[_tree_to_info(child, agents_map) for child in tree.children],
agents=agents_map.get(tree.id) if agents_map else None,
)
def _to_agent_summaries(
raw: list[dict[str, str | None]],
) -> list[FolderAgentSummary]:
"""Convert raw agent dicts to typed FolderAgentSummary models."""
return [
FolderAgentSummary(
id=a["id"] or "",
name=a["name"] or "",
description=a["description"] or "",
)
for a in raw
]
def _to_agent_summaries_map(
raw: dict[str, list[dict[str, str | None]]],
) -> dict[str, list[FolderAgentSummary]]:
"""Convert a folder-id-keyed dict of raw agents to typed summaries."""
return {fid: _to_agent_summaries(agents) for fid, agents in raw.items()}
class CreateFolderTool(BaseTool):
"""Tool for creating a library folder."""
@property
def name(self) -> str:
return "create_folder"
@property
def description(self) -> str:
return (
"Create a new folder in the user's library to organize agents. "
"Optionally nest it inside an existing folder using parent_id."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name for the new folder (max 100 chars).",
},
"parent_id": {
"type": "string",
"description": (
"ID of the parent folder to nest inside. "
"Omit to create at root level."
),
},
"icon": {
"type": "string",
"description": "Optional icon identifier for the folder.",
},
"color": {
"type": "string",
"description": "Optional hex color code (#RRGGBB).",
},
},
"required": ["name"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Create a folder with the given name and optional parent/icon/color."""
assert user_id is not None # guaranteed by requires_auth
name = (kwargs.get("name") or "").strip()
parent_id = kwargs.get("parent_id")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not name:
return ErrorResponse(
message="Please provide a folder name.",
error="missing_name",
session_id=session_id,
)
try:
folder = await library_db().create_folder(
user_id=user_id,
name=name,
parent_id=parent_id,
icon=icon,
color=color,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to create folder: {e}",
error="create_folder_failed",
session_id=session_id,
)
return FolderCreatedResponse(
message=f"Folder '{folder.name}' created successfully!",
folder=_folder_to_info(folder),
session_id=session_id,
)
class ListFoldersTool(BaseTool):
"""Tool for listing library folders."""
@property
def name(self) -> str:
return "list_folders"
@property
def description(self) -> str:
return (
"List the user's library folders. "
"Omit parent_id to get the full folder tree. "
"Provide parent_id to list only direct children of that folder. "
"Set include_agents=true to also return the agents inside each folder "
"and root-level agents not in any folder. Always set include_agents=true "
"when the user asks about agents, wants to see what's in their folders, "
"or mentions agents alongside folders."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"parent_id": {
"type": "string",
"description": (
"List children of this folder. "
"Omit to get the full folder tree."
),
},
"include_agents": {
"type": "boolean",
"description": (
"Whether to include the list of agents inside each folder. "
"Defaults to false."
),
},
},
"required": [],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""List folders as a flat list (by parent) or full tree."""
assert user_id is not None # guaranteed by requires_auth
parent_id = kwargs.get("parent_id")
include_agents = kwargs.get("include_agents", False)
session_id = session.session_id if session else None
try:
if parent_id:
folders = await library_db().list_folders(
user_id=user_id, parent_id=parent_id
)
raw_map = (
await library_db().get_folder_agents_map(
user_id, [f.id for f in folders]
)
if include_agents
else None
)
agents_map = _to_agent_summaries_map(raw_map) if raw_map else None
return FolderListResponse(
message=f"Found {len(folders)} folder(s).",
folders=[
_folder_to_info(f, agents_map.get(f.id) if agents_map else None)
for f in folders
],
count=len(folders),
session_id=session_id,
)
else:
tree = await library_db().get_folder_tree(user_id=user_id)
all_ids = collect_tree_ids(tree)
agents_map = None
root_agents = None
if include_agents:
raw_map = await library_db().get_folder_agents_map(user_id, all_ids)
agents_map = _to_agent_summaries_map(raw_map)
root_agents = _to_agent_summaries(
await library_db().get_root_agent_summaries(user_id)
)
return FolderListResponse(
message=f"Found {len(all_ids)} folder(s) in your library.",
tree=[_tree_to_info(t, agents_map) for t in tree],
root_agents=root_agents,
count=len(all_ids),
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to list folders: {e}",
error="list_folders_failed",
session_id=session_id,
)
class UpdateFolderTool(BaseTool):
"""Tool for updating a folder's properties."""
@property
def name(self) -> str:
return "update_folder"
@property
def description(self) -> str:
return "Update a folder's name, icon, or color."
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to update.",
},
"name": {
"type": "string",
"description": "New name for the folder.",
},
"icon": {
"type": "string",
"description": "New icon identifier.",
},
"color": {
"type": "string",
"description": "New hex color code (#RRGGBB).",
},
},
"required": ["folder_id"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Update a folder's name, icon, or color."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (kwargs.get("folder_id") or "").strip()
name = kwargs.get("name")
icon = kwargs.get("icon")
color = kwargs.get("color")
session_id = session.session_id if session else None
if not folder_id:
return ErrorResponse(
message="Please provide a folder_id.",
error="missing_folder_id",
session_id=session_id,
)
try:
folder = await library_db().update_folder(
folder_id=folder_id,
user_id=user_id,
name=name,
icon=icon,
color=color,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to update folder: {e}",
error="update_folder_failed",
session_id=session_id,
)
return FolderUpdatedResponse(
message=f"Folder updated to '{folder.name}'.",
folder=_folder_to_info(folder),
session_id=session_id,
)
class MoveFolderTool(BaseTool):
"""Tool for moving a folder to a new parent."""
@property
def name(self) -> str:
return "move_folder"
@property
def description(self) -> str:
return (
"Move a folder to a different parent folder. "
"Set target_parent_id to null to move to root level."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to move.",
},
"target_parent_id": {
"type": ["string", "null"],
"description": (
"ID of the new parent folder. "
"Use null to move to root level."
),
},
},
"required": ["folder_id"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move a folder to a new parent or to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (kwargs.get("folder_id") or "").strip()
target_parent_id = kwargs.get("target_parent_id")
session_id = session.session_id if session else None
if not folder_id:
return ErrorResponse(
message="Please provide a folder_id.",
error="missing_folder_id",
session_id=session_id,
)
try:
folder = await library_db().move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=target_parent_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to move folder: {e}",
error="move_folder_failed",
session_id=session_id,
)
dest = "a subfolder" if target_parent_id else "root level"
return FolderMovedResponse(
message=f"Folder '{folder.name}' moved to {dest}.",
folder=_folder_to_info(folder),
target_parent_id=target_parent_id,
session_id=session_id,
)
class DeleteFolderTool(BaseTool):
"""Tool for deleting a folder."""
@property
def name(self) -> str:
return "delete_folder"
@property
def description(self) -> str:
return (
"Delete a folder from the user's library. "
"Agents inside the folder are moved to root level (not deleted)."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to delete.",
},
},
"required": ["folder_id"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Soft-delete a folder; agents inside are moved to root level."""
assert user_id is not None # guaranteed by requires_auth
folder_id = (kwargs.get("folder_id") or "").strip()
session_id = session.session_id if session else None
if not folder_id:
return ErrorResponse(
message="Please provide a folder_id.",
error="missing_folder_id",
session_id=session_id,
)
try:
await library_db().delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to delete folder: {e}",
error="delete_folder_failed",
session_id=session_id,
)
return FolderDeletedResponse(
message="Folder deleted. Any agents inside were moved to root level.",
folder_id=folder_id,
session_id=session_id,
)
class MoveAgentsToFolderTool(BaseTool):
"""Tool for moving agents into a folder."""
@property
def name(self) -> str:
return "move_agents_to_folder"
@property
def description(self) -> str:
return (
"Move one or more agents to a folder. "
"Set folder_id to null to move agents to root level."
)
@property
def requires_auth(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": "List of library agent IDs to move.",
},
"folder_id": {
"type": ["string", "null"],
"description": (
"Target folder ID. Use null to move to root level."
),
},
},
"required": ["agent_ids"],
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Move one or more agents to a folder or to root level."""
assert user_id is not None # guaranteed by requires_auth
agent_ids = kwargs.get("agent_ids", [])
folder_id = kwargs.get("folder_id")
session_id = session.session_id if session else None
if not agent_ids:
return ErrorResponse(
message="Please provide at least one agent ID.",
error="missing_agent_ids",
session_id=session_id,
)
try:
moved = await library_db().bulk_move_agents_to_folder(
agent_ids=agent_ids,
folder_id=folder_id,
user_id=user_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to move agents: {e}",
error="move_agents_failed",
session_id=session_id,
)
moved_ids = [a.id for a in moved]
agent_names = [a.name for a in moved]
dest = "the folder" if folder_id else "root level"
names_str = (
", ".join(agent_names) if agent_names else f"{len(agent_ids)} agent(s)"
)
return AgentsMovedToFolderResponse(
message=f"Moved {names_str} to {dest}.",
agent_ids=moved_ids,
agent_names=agent_names,
folder_id=folder_id,
count=len(moved),
session_id=session_id,
)

View File

@@ -0,0 +1,455 @@
"""Tests for folder management copilot tools."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.library import model as library_model
from backend.copilot.tools.manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from backend.copilot.tools.models import (
AgentsMovedToFolderResponse,
ErrorResponse,
FolderCreatedResponse,
FolderDeletedResponse,
FolderListResponse,
FolderMovedResponse,
FolderUpdatedResponse,
)
from ._test_data import make_session
_TEST_USER_ID = "test-user-folders"
_NOW = datetime.now(UTC)
def _make_folder(
id: str = "folder-1",
name: str = "My Folder",
parent_id: str | None = None,
icon: str | None = None,
color: str | None = None,
agent_count: int = 0,
subfolder_count: int = 0,
) -> library_model.LibraryFolder:
return library_model.LibraryFolder(
id=id,
user_id=_TEST_USER_ID,
name=name,
icon=icon,
color=color,
parent_id=parent_id,
created_at=_NOW,
updated_at=_NOW,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
def _make_tree(
id: str = "folder-1",
name: str = "Root",
children: list[library_model.LibraryFolderTree] | None = None,
) -> library_model.LibraryFolderTree:
return library_model.LibraryFolderTree(
id=id,
user_id=_TEST_USER_ID,
name=name,
created_at=_NOW,
updated_at=_NOW,
children=children or [],
)
def _make_library_agent(id: str = "agent-1", name: str = "Test Agent"):
agent = MagicMock()
agent.id = id
agent.name = name
return agent
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
# ── CreateFolderTool ──
@pytest.fixture
def create_tool():
return CreateFolderTool()
@pytest.mark.asyncio
async def test_create_folder_missing_name(create_tool, session):
result = await create_tool._execute(user_id=_TEST_USER_ID, session=session, name="")
assert isinstance(result, ErrorResponse)
assert result.error == "missing_name"
@pytest.mark.asyncio
async def test_create_folder_none_name(create_tool, session):
result = await create_tool._execute(
user_id=_TEST_USER_ID, session=session, name=None
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_name"
@pytest.mark.asyncio
async def test_create_folder_success(create_tool, session):
folder = _make_folder(name="New Folder")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.create_folder = AsyncMock(return_value=folder)
result = await create_tool._execute(
user_id=_TEST_USER_ID, session=session, name="New Folder"
)
assert isinstance(result, FolderCreatedResponse)
assert result.folder.name == "New Folder"
assert "New Folder" in result.message
@pytest.mark.asyncio
async def test_create_folder_db_error(create_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.create_folder = AsyncMock(
side_effect=Exception("db down")
)
result = await create_tool._execute(
user_id=_TEST_USER_ID, session=session, name="Folder"
)
assert isinstance(result, ErrorResponse)
assert result.error == "create_folder_failed"
# ── ListFoldersTool ──
@pytest.fixture
def list_tool():
return ListFoldersTool()
@pytest.mark.asyncio
async def test_list_folders_by_parent(list_tool, session):
folders = [_make_folder(id="f1", name="A"), _make_folder(id="f2", name="B")]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.list_folders = AsyncMock(return_value=folders)
result = await list_tool._execute(
user_id=_TEST_USER_ID, session=session, parent_id="parent-1"
)
assert isinstance(result, FolderListResponse)
assert result.count == 2
assert len(result.folders) == 2
@pytest.mark.asyncio
async def test_list_folders_tree(list_tool, session):
tree = [
_make_tree(id="r1", name="Root", children=[_make_tree(id="c1", name="Child")])
]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, FolderListResponse)
assert result.count == 2 # root + child
assert result.tree is not None
assert len(result.tree) == 1
@pytest.mark.asyncio
async def test_list_folders_tree_with_agents_includes_root(list_tool, session):
tree = [_make_tree(id="r1", name="Root")]
raw_map = {"r1": [{"id": "a1", "name": "Foldered", "description": "In folder"}]}
root_raw = [{"id": "a2", "name": "Loose Agent", "description": "At root"}]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
mock_lib.return_value.get_folder_agents_map = AsyncMock(return_value=raw_map)
mock_lib.return_value.get_root_agent_summaries = AsyncMock(
return_value=root_raw
)
result = await list_tool._execute(
user_id=_TEST_USER_ID, session=session, include_agents=True
)
assert isinstance(result, FolderListResponse)
assert result.root_agents is not None
assert len(result.root_agents) == 1
assert result.root_agents[0].name == "Loose Agent"
assert result.tree is not None
assert result.tree[0].agents is not None
assert result.tree[0].agents[0].name == "Foldered"
mock_lib.return_value.get_root_agent_summaries.assert_awaited_once_with(
_TEST_USER_ID
)
@pytest.mark.asyncio
async def test_list_folders_tree_without_agents_no_root(list_tool, session):
tree = [_make_tree(id="r1", name="Root")]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
result = await list_tool._execute(
user_id=_TEST_USER_ID, session=session, include_agents=False
)
assert isinstance(result, FolderListResponse)
assert result.root_agents is None
@pytest.mark.asyncio
async def test_list_folders_db_error(list_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.get_folder_tree = AsyncMock(
side_effect=Exception("timeout")
)
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error == "list_folders_failed"
# ── UpdateFolderTool ──
@pytest.fixture
def update_tool():
return UpdateFolderTool()
@pytest.mark.asyncio
async def test_update_folder_missing_id(update_tool, session):
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_update_folder_none_id(update_tool, session):
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=None
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_update_folder_success(update_tool, session):
folder = _make_folder(name="Renamed")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.update_folder = AsyncMock(return_value=folder)
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="Renamed"
)
assert isinstance(result, FolderUpdatedResponse)
assert result.folder.name == "Renamed"
@pytest.mark.asyncio
async def test_update_folder_db_error(update_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.update_folder = AsyncMock(
side_effect=Exception("not found")
)
result = await update_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="X"
)
assert isinstance(result, ErrorResponse)
assert result.error == "update_folder_failed"
# ── MoveFolderTool ──
@pytest.fixture
def move_tool():
return MoveFolderTool()
@pytest.mark.asyncio
async def test_move_folder_missing_id(move_tool, session):
result = await move_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_move_folder_to_parent(move_tool, session):
folder = _make_folder(name="Moved")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
result = await move_tool._execute(
user_id=_TEST_USER_ID,
session=session,
folder_id="folder-1",
target_parent_id="parent-1",
)
assert isinstance(result, FolderMovedResponse)
assert "subfolder" in result.message
@pytest.mark.asyncio
async def test_move_folder_to_root(move_tool, session):
folder = _make_folder(name="Moved")
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
result = await move_tool._execute(
user_id=_TEST_USER_ID,
session=session,
folder_id="folder-1",
target_parent_id=None,
)
assert isinstance(result, FolderMovedResponse)
assert "root level" in result.message
@pytest.mark.asyncio
async def test_move_folder_db_error(move_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.move_folder = AsyncMock(side_effect=Exception("circular"))
result = await move_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
)
assert isinstance(result, ErrorResponse)
assert result.error == "move_folder_failed"
# ── DeleteFolderTool ──
@pytest.fixture
def delete_tool():
return DeleteFolderTool()
@pytest.mark.asyncio
async def test_delete_folder_missing_id(delete_tool, session):
result = await delete_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_folder_id"
@pytest.mark.asyncio
async def test_delete_folder_success(delete_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.delete_folder = AsyncMock(return_value=None)
result = await delete_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
)
assert isinstance(result, FolderDeletedResponse)
assert result.folder_id == "folder-1"
assert "root level" in result.message
@pytest.mark.asyncio
async def test_delete_folder_db_error(delete_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.delete_folder = AsyncMock(
side_effect=Exception("permission denied")
)
result = await delete_tool._execute(
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
)
assert isinstance(result, ErrorResponse)
assert result.error == "delete_folder_failed"
# ── MoveAgentsToFolderTool ──
@pytest.fixture
def move_agents_tool():
return MoveAgentsToFolderTool()
@pytest.mark.asyncio
async def test_move_agents_missing_ids(move_agents_tool, session):
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID, session=session, agent_ids=[]
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_agent_ids"
@pytest.mark.asyncio
async def test_move_agents_success(move_agents_tool, session):
agents = [
_make_library_agent(id="a1", name="Agent Alpha"),
_make_library_agent(id="a2", name="Agent Beta"),
]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
return_value=agents
)
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_ids=["a1", "a2"],
folder_id="folder-1",
)
assert isinstance(result, AgentsMovedToFolderResponse)
assert result.count == 2
assert result.agent_names == ["Agent Alpha", "Agent Beta"]
assert "Agent Alpha" in result.message
assert "Agent Beta" in result.message
@pytest.mark.asyncio
async def test_move_agents_to_root(move_agents_tool, session):
agents = [_make_library_agent(id="a1", name="Agent One")]
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
return_value=agents
)
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_ids=["a1"],
folder_id=None,
)
assert isinstance(result, AgentsMovedToFolderResponse)
assert "root level" in result.message
@pytest.mark.asyncio
async def test_move_agents_db_error(move_agents_tool, session):
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
side_effect=Exception("folder not found")
)
result = await move_agents_tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_ids=["a1"],
folder_id="bad-folder",
)
assert isinstance(result, ErrorResponse)
assert result.error == "move_agents_failed"

View File

@@ -55,6 +55,13 @@ class ResponseType(str, Enum):
# MCP tool types
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Folder management types
FOLDER_CREATED = "folder_created"
FOLDER_LIST = "folder_list"
FOLDER_UPDATED = "folder_updated"
FOLDER_MOVED = "folder_moved"
FOLDER_DELETED = "folder_deleted"
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
# Base response model
@@ -539,3 +546,82 @@ class BrowserScreenshotResponse(ToolResponseBase):
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
file_id: str # Workspace file ID — use read_workspace_file to retrieve
filename: str
# Folder management models
class FolderAgentSummary(BaseModel):
"""Lightweight agent info for folder listings."""
id: str
name: str
description: str = ""
class FolderInfo(BaseModel):
"""Information about a folder."""
id: str
name: str
parent_id: str | None = None
icon: str | None = None
color: str | None = None
agent_count: int = 0
subfolder_count: int = 0
agents: list[FolderAgentSummary] | None = None
class FolderTreeInfo(FolderInfo):
"""Folder with nested children for tree display."""
children: list["FolderTreeInfo"] = []
class FolderCreatedResponse(ToolResponseBase):
"""Response when a folder is created."""
type: ResponseType = ResponseType.FOLDER_CREATED
folder: FolderInfo
class FolderListResponse(ToolResponseBase):
"""Response for listing folders."""
type: ResponseType = ResponseType.FOLDER_LIST
folders: list[FolderInfo] = Field(default_factory=list)
tree: list[FolderTreeInfo] | None = None
root_agents: list[FolderAgentSummary] | None = None
count: int = 0
class FolderUpdatedResponse(ToolResponseBase):
"""Response when a folder is updated."""
type: ResponseType = ResponseType.FOLDER_UPDATED
folder: FolderInfo
class FolderMovedResponse(ToolResponseBase):
"""Response when a folder is moved."""
type: ResponseType = ResponseType.FOLDER_MOVED
folder: FolderInfo
target_parent_id: str | None = None
class FolderDeletedResponse(ToolResponseBase):
"""Response when a folder is deleted."""
type: ResponseType = ResponseType.FOLDER_DELETED
folder_id: str
class AgentsMovedToFolderResponse(ToolResponseBase):
"""Response when agents are moved to a folder."""
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
agent_ids: list[str]
agent_names: list[str] = []
folder_id: str | None = None
count: int = 0

View File

@@ -53,11 +53,15 @@ class RunMCPToolTool(BaseTool):
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Call with just `server_url` to see available tools. "
"Then call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"If the server requires authentication, the user will be prompted to connect it. "
"Find MCP servers at https://registry.modelcontextprotocol.io/ — hundreds of integrations "
"including GitHub, Postgres, Slack, filesystem, and more."
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
"Once connected and user confirms, retry the same call immediately."
)
@property

View File

@@ -81,6 +81,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_6_OPUS: 14,
LlmModel.CLAUDE_4_6_SONNET: 9,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,

View File

@@ -4,11 +4,20 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
from backend.api.features.library.db import (
add_store_agent_to_library,
bulk_move_agents_to_folder,
create_folder,
create_graph_in_library,
create_library_agent,
delete_folder,
get_folder_agents_map,
get_folder_tree,
get_library_agent,
get_library_agent_by_graph_id,
get_root_agent_summaries,
list_folders,
list_library_agents,
move_folder,
update_folder,
update_graph_in_library,
)
from backend.api.features.store.db import (
@@ -260,6 +269,16 @@ class DatabaseManager(AppService):
update_graph_in_library = _(update_graph_in_library)
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
create_folder = _(create_folder)
list_folders = _(list_folders)
get_folder_tree = _(get_folder_tree)
update_folder = _(update_folder)
move_folder = _(move_folder)
delete_folder = _(delete_folder)
bulk_move_agents_to_folder = _(bulk_move_agents_to_folder)
get_folder_agents_map = _(get_folder_agents_map)
get_root_agent_summaries = _(get_root_agent_summaries)
# ============ Onboarding ============ #
increment_onboarding_runs = _(increment_onboarding_runs)
@@ -305,6 +324,7 @@ class DatabaseManager(AppService):
delete_chat_session = _(chat_db.delete_chat_session)
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
update_chat_session_title = _(chat_db.update_chat_session_title)
class DatabaseManagerClient(AppServiceClient):
@@ -433,6 +453,17 @@ class DatabaseManagerAsyncClient(AppServiceClient):
update_graph_in_library = d.update_graph_in_library
validate_graph_execution_permissions = d.validate_graph_execution_permissions
# ============ Library Folders ============ #
create_folder = d.create_folder
list_folders = d.list_folders
get_folder_tree = d.get_folder_tree
update_folder = d.update_folder
move_folder = d.move_folder
delete_folder = d.delete_folder
bulk_move_agents_to_folder = d.bulk_move_agents_to_folder
get_folder_agents_map = d.get_folder_agents_map
get_root_agent_summaries = d.get_root_agent_summaries
# ============ Onboarding ============ #
increment_onboarding_runs = d.increment_onboarding_runs
@@ -475,3 +506,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
delete_chat_session = d.delete_chat_session
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content
update_chat_session_title = d.update_chat_session_title

View File

@@ -344,7 +344,7 @@ class GraphExecution(GraphExecutionMeta):
),
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
"payload": exec.input_data.get("payload")
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
from uuid import UUID
import fastapi.exceptions
import prisma
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -250,8 +251,8 @@ async def test_clean_graph(server: SpinTestServer):
"_test_id": "node_with_secrets",
"input": "normal_value",
"control_test_input": "should be preserved",
"api_key": "secret_api_key_123", # Should be filtered
"password": "secret_password_456", # Should be filtered
"api_key": "secret_api_key_123", # Should be filtered # pragma: allowlist secret # noqa
"password": "secret_password_456", # Should be filtered # pragma: allowlist secret # noqa
"token": "secret_token_789", # Should be filtered
"credentials": { # Should be filtered
"id": "fake-github-credentials-id",
@@ -354,9 +355,24 @@ async def test_access_store_listing_graph(server: SpinTestServer):
create_graph, DEFAULT_USER_ID
)
# Ensure the default user has a Profile (required for store submissions)
existing_profile = await prisma.models.Profile.prisma().find_first(
where={"userId": DEFAULT_USER_ID}
)
if not existing_profile:
await prisma.models.Profile.prisma().create(
data=prisma.types.ProfileCreateInput(
userId=DEFAULT_USER_ID,
name="Default User",
username=f"default-user-{DEFAULT_USER_ID[:8]}",
description="Default test user profile",
links=[],
)
)
store_submission_request = store.StoreSubmissionRequest(
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=created_graph.id,
name="Test name",
sub_heading="Test sub heading",
@@ -385,8 +401,8 @@ async def test_access_store_listing_graph(server: SpinTestServer):
assert False, "Failed to create store listing"
slv_id = (
store_listing.store_listing_version_id
if store_listing.store_listing_version_id is not None
store_listing.listing_version_id
if store_listing.listing_version_id is not None
else None
)

View File

@@ -184,17 +184,17 @@ async def find_webhook_by_credentials_and_props(
credentials_id: str,
webhook_type: str,
resource: str,
events: Optional[list[str]],
events: list[str] | None = None,
) -> Webhook | None:
webhook = await IntegrationWebhook.prisma().find_first(
where={
"userId": user_id,
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
**({"events": {"has_every": events}} if events else {}),
},
)
where: IntegrationWebhookWhereInput = {
"userId": user_id,
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
}
if events is not None:
where["events"] = {"has_every": events}
webhook = await IntegrationWebhook.prisma().find_first(where=where)
return Webhook.from_db(webhook) if webhook else None

View File

@@ -0,0 +1,601 @@
import asyncio
import csv
import io
import logging
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, Optional
from uuid import uuid4
import prisma.enums
import prisma.models
import prisma.types
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
from backend.data.db import transaction
from backend.data.model import User
from backend.data.tally import get_business_understanding_input_from_tally
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
)
from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
_tally_seed_tasks: set[asyncio.Task] = set()
_email_adapter = TypeAdapter(EmailStr)
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
MAX_BULK_INVITE_ROWS = 500
class InvitedUserRecord(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
created_at=invited_user.createdAt,
updated_at=invited_user.updatedAt,
)
class BulkInvitedUserRowResult(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserRecord] = None
class BulkInvitedUsersResult(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResult]
@dataclass(frozen=True)
class _ParsedInviteRow:
row_number: int
email: str
name: Optional[str]
def normalize_email(email: str) -> str:
return email.strip().lower()
def _normalize_name(name: Optional[str]) -> Optional[str]:
if name is None:
return None
normalized = name.strip()
return normalized or None
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
if preferred_name:
return preferred_name
local_part = email.split("@", 1)[0].strip()
return local_part or "user"
def _sanitize_username_base(email: str) -> str:
local_part = email.split("@", 1)[0].lower()
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
sanitized = sanitized.strip("-")
return sanitized[:40] or "user"
async def _generate_unique_profile_username(email: str, tx) -> str:
base = _sanitize_username_base(email)
for attempt in range(10):
candidate = base if attempt == 0 else f"{base}-{uuid4().hex[:6]}"
existing = await prisma.models.Profile.prisma(tx).find_unique(
where={"username": candidate}
)
if existing is None:
return candidate
raise RuntimeError(f"Unable to generate unique username for {email}")
async def _ensure_default_profile(
user_id: str,
email: str,
preferred_name: Optional[str],
tx,
) -> None:
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
where={"userId": user_id}
)
if existing_profile is not None:
return
username = await _generate_unique_profile_username(email, tx)
await prisma.models.Profile.prisma(tx).create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name=_default_profile_name(email, preferred_name),
username=username,
description="I'm new here",
links=[],
avatarUrl="",
)
)
async def _ensure_default_onboarding(user_id: str, tx) -> None:
await prisma.models.UserOnboarding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def _apply_tally_understanding(
user_id: str,
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
if not isinstance(invited_user.tallyUnderstanding, dict):
return
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
payload = merge_business_understanding_data({}, input_data)
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(payload)},
"update": {"data": SafeJson(payload)},
},
)
async def list_invited_users() -> list[InvitedUserRecord]:
invited_users = await prisma.models.InvitedUser.prisma().find_many(
order={"createdAt": "desc"}
)
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:
normalized_email = normalize_email(email)
normalized_name = _normalize_name(name)
existing_user = await prisma.models.User.prisma().find_unique(
where={"email": normalized_email}
)
if existing_user is not None:
raise PreconditionFailed("An active user with this email already exists")
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if existing_invited_user is not None:
raise PreconditionFailed("An invited user with this email already exists")
invited_user = await prisma.models.InvitedUser.prisma().create(
data={
"email": normalized_email,
"name": normalized_name,
"status": prisma.enums.InvitedUserStatus.INVITED,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
}
)
schedule_invited_user_tally_precompute(invited_user.id)
return InvitedUserRecord.from_db(invited_user)
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
raise PreconditionFailed("Claimed invited users cannot be revoked")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return InvitedUserRecord.from_db(invited_user)
revoked_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
)
if revoked_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
return InvitedUserRecord.from_db(revoked_user)
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
refreshed_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": None,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
"tallyComputedAt": None,
"tallyError": None,
},
)
if refreshed_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
schedule_invited_user_tally_precompute(invited_user_id)
return InvitedUserRecord.from_db(refreshed_user)
def _decode_bulk_invite_file(content: bytes) -> str:
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
raise ValueError("Invite file exceeds the maximum size of 1 MB")
try:
return content.decode("utf-8-sig")
except UnicodeDecodeError as exc:
raise ValueError("Invite file must be UTF-8 encoded") from exc
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
indexed_rows: list[tuple[int, list[str]]] = []
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
normalized_row = [cell.strip() for cell in row]
if any(normalized_row):
indexed_rows.append((row_number, normalized_row))
if not indexed_rows:
return []
header = [cell.lower() for cell in indexed_rows[0][1]]
has_header = "email" in header
email_index = header.index("email") if has_header else 0
name_index = header.index("name") if has_header and "name" in header else 1
data_rows = indexed_rows[1:] if has_header else indexed_rows
parsed_rows: list[_ParsedInviteRow] = []
for row_number, row in data_rows:
email = row[email_index].strip() if len(row) > email_index else ""
name = row[name_index].strip() if len(row) > name_index else ""
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=email,
name=name or None,
)
)
return parsed_rows
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
parsed_rows: list[_ParsedInviteRow] = []
for row_number, raw_line in enumerate(text.splitlines(), start=1):
line = raw_line.strip()
if not line or line.startswith("#"):
continue
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=line,
name=None,
)
)
return parsed_rows
def _parse_bulk_invite_file(
filename: Optional[str],
content: bytes,
) -> list[_ParsedInviteRow]:
text = _decode_bulk_invite_file(content)
file_name = filename.lower() if filename else ""
parsed_rows = (
_parse_bulk_invite_csv(text)
if file_name.endswith(".csv")
else _parse_bulk_invite_text(text)
)
if not parsed_rows:
raise ValueError("Invite file did not contain any emails")
if len(parsed_rows) > MAX_BULK_INVITE_ROWS:
raise ValueError(
f"Invite file contains too many rows. Maximum supported rows: {MAX_BULK_INVITE_ROWS}"
)
return parsed_rows
async def bulk_create_invited_users_from_file(
filename: Optional[str],
content: bytes,
) -> BulkInvitedUsersResult:
parsed_rows = _parse_bulk_invite_file(filename, content)
created_count = 0
skipped_count = 0
error_count = 0
results: list[BulkInvitedUserRowResult] = []
seen_emails: set[str] = set()
for row in parsed_rows:
row_name = _normalize_name(row.name)
try:
validated_email = _email_adapter.validate_python(row.email)
except ValidationError:
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=row.email or None,
name=row_name,
status="ERROR",
message="Invalid email address",
)
)
continue
normalized_email = normalize_email(str(validated_email))
if normalized_email in seen_emails:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message="Duplicate email in upload file",
)
)
continue
seen_emails.add(normalized_email)
try:
invited_user = await create_invited_user(normalized_email, row_name)
except PreconditionFailed as exc:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message=str(exc),
)
)
except Exception:
logger.exception(
"Failed to create bulk invite for %s from row %s",
normalized_email,
row.row_number,
)
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="ERROR",
message="Unexpected error creating invite",
)
)
else:
created_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="CREATED",
message="Invite created",
invited_user=invited_user,
)
)
return BulkInvitedUsersResult(
created_count=created_count,
skipped_count=skipped_count,
error_count=error_count,
results=results,
)
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
return
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
"tallyError": None,
},
)
try:
input_data = await get_business_understanding_input_from_tally(
invited_user.email,
require_api_key=True,
)
payload = (
SafeJson(input_data.model_dump(exclude_none=True))
if input_data is not None
else None
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": payload,
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
"tallyComputedAt": datetime.now(timezone.utc),
"tallyError": None,
},
)
except Exception as exc:
logger.exception(
"Failed to compute Tally understanding for invited user %s",
invited_user_id,
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
"tallyError": str(exc),
},
)
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
_tally_seed_tasks.add(task)
task.add_done_callback(_tally_seed_tasks.discard)
async def get_or_activate_user(user_data: dict) -> User:
auth_user_id = user_data.get("sub")
if not auth_user_id:
raise NotAuthorizedError("User ID not found in token")
auth_email = user_data.get("email")
if not auth_email:
raise NotAuthorizedError("Email not found in token")
normalized_email = normalize_email(auth_email)
user_metadata = user_data.get("user_metadata")
metadata_name = (
user_metadata.get("name") if isinstance(user_metadata, dict) else None
)
existing_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if existing_user is not None:
return User.from_db(existing_user)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
async with transaction() as tx:
current_user = await prisma.models.User.prisma(tx).find_unique(
where={"id": auth_user_id}
)
if current_user is not None:
return User.from_db(current_user)
current_invited_user = await prisma.models.InvitedUser.prisma(tx).find_unique(
where={"email": normalized_email}
)
if current_invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
if current_invited_user.authUserId not in (None, auth_user_id):
raise NotAuthorizedError("Your invitation has already been claimed")
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await prisma.models.InvitedUser.prisma(tx).update(
where={"id": current_invited_user.id},
data={
"status": prisma.enums.InvitedUserStatus.CLAIMED,
"authUserId": auth_user_id,
},
)
await _ensure_default_profile(
auth_user_id,
normalized_email,
preferred_name,
tx,
)
await _ensure_default_onboarding(auth_user_id, tx)
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
from backend.data.user import get_user_by_email, get_user_by_id
get_user_by_id.cache_delete(auth_user_id)
get_user_by_email.cache_delete(normalized_email)
activated_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if activated_user is None:
raise RuntimeError(
f"Activated user {auth_user_id} was not found after creation"
)
return User.from_db(activated_user)

View File

@@ -0,0 +1,297 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import prisma.enums
import pytest
import pytest_mock
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
from .invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
get_or_activate_user,
retry_invited_user_tally,
)
def _invited_user_db_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="invite-1",
email="invited@example.com",
status=status,
authUserId=None,
name="Invited User",
tallyUnderstanding=tally_understanding,
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
tallyComputedAt=None,
tallyError=None,
createdAt=now,
updatedAt=now,
)
def _user_db_record():
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="auth-user-1",
email="invited@example.com",
emailVerified=True,
name="Invited User",
createdAt=now,
updatedAt=now,
metadata={},
integrations="",
stripeCustomerId=None,
topUpConfig=None,
maxEmailsPerDay=3,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="not-set",
)
@pytest.mark.asyncio
async def test_create_invited_user_rejects_existing_active_user(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock()
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(PreconditionFailed):
await create_invited_user("Invited@example.com")
@pytest.mark.asyncio
async def test_create_invited_user_schedules_tally_seed(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await create_invited_user("Invited@example.com", "Invited User")
assert invited_user.email == "invited@example.com"
invited_user_repo.create.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_retry_invited_user_tally_resets_state_and_schedules(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await retry_invited_user_tally("invite-1")
assert invited_user.id == "invite-1"
invited_user_repo.update.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_get_or_activate_user_requires_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(NotAuthorizedError):
await get_or_activate_user(
{"sub": "auth-user-1", "email": "invited@example.com"}
)
@pytest.mark.asyncio
async def test_get_or_activate_user_creates_user_from_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
tx = object()
invited_user = _invited_user_db_record(
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
)
created_user = _user_db_record()
outside_user_repo = Mock()
outside_user_repo.find_unique = AsyncMock(side_effect=[None, created_user])
inside_user_repo = Mock()
inside_user_repo.find_unique = AsyncMock(return_value=None)
inside_user_repo.create = AsyncMock(return_value=created_user)
outside_invited_repo = Mock()
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo = Mock()
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo.update = AsyncMock(return_value=invited_user)
def user_prisma(client=None):
return inside_user_repo if client is tx else outside_user_repo
def invited_user_prisma(client=None):
return inside_invited_repo if client is tx else outside_invited_repo
@asynccontextmanager
async def fake_transaction():
yield tx
ensure_profile = mocker.patch(
"backend.data.invited_user._ensure_default_profile",
AsyncMock(),
)
ensure_onboarding = mocker.patch(
"backend.data.invited_user._ensure_default_onboarding",
AsyncMock(),
)
apply_tally = mocker.patch(
"backend.data.invited_user._apply_tally_understanding",
AsyncMock(),
)
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
side_effect=invited_user_prisma,
)
mocker.patch("backend.data.user.get_user_by_id.cache_delete", Mock())
mocker.patch("backend.data.user.get_user_by_email.cache_delete", Mock())
user = await get_or_activate_user(
{
"sub": "auth-user-1",
"email": "Invited@example.com",
"user_metadata": {"name": "Invited User"},
}
)
assert user.id == "auth-user-1"
inside_user_repo.create.assert_awaited_once()
inside_invited_repo.update.assert_awaited_once()
ensure_profile.assert_awaited_once()
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
@pytest.mark.asyncio
async def test_bulk_create_invited_users_from_text_file(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_db_record(),
_invited_user_db_record(),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.txt",
b"Invited@example.com\nsecond@example.com\n",
)
assert result.created_count == 2
assert result.skipped_count == 0
assert result.error_count == 0
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
assert create_invited.await_count == 2
@pytest.mark.asyncio
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_db_record(),
PreconditionFailed("An invited user with this email already exists"),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.csv",
(
"email,name\n"
"valid@example.com,Valid User\n"
"not-an-email,Bad Row\n"
"valid@example.com,Duplicate In File\n"
"existing@example.com,Existing User\n"
).encode("utf-8"),
)
assert result.created_count == 1
assert result.skipped_count == 2
assert result.error_count == 1
assert [row.status for row in result.results] == [
"CREATED",
"ERROR",
"SKIPPED",
"SKIPPED",
]
assert create_invited.await_count == 2

View File

@@ -13,7 +13,14 @@ from prisma.types import (
)
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from pydantic import (
BaseModel,
ConfigDict,
EmailStr,
Field,
field_validator,
model_validator,
)
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
@@ -175,10 +182,26 @@ class RefundRequestData(BaseNotificationData):
balance: int
class AgentApprovalData(BaseNotificationData):
class _LegacyAgentFieldsMixin:
"""Temporary patch to handle existing queued payloads"""
# FIXME: remove in next release
@model_validator(mode="before")
@classmethod
def _map_legacy_agent_fields(cls, values: Any):
if isinstance(values, dict):
if "graph_id" not in values and "agent_id" in values:
values["graph_id"] = values.pop("agent_id")
if "graph_version" not in values and "agent_version" in values:
values["graph_version"] = values.pop("agent_version")
return values
class AgentApprovalData(_LegacyAgentFieldsMixin, BaseNotificationData):
agent_name: str
agent_id: str
agent_version: int
graph_id: str
graph_version: int
reviewer_name: str
reviewer_email: str
comments: str
@@ -193,10 +216,10 @@ class AgentApprovalData(BaseNotificationData):
return value
class AgentRejectionData(BaseNotificationData):
class AgentRejectionData(_LegacyAgentFieldsMixin, BaseNotificationData):
agent_name: str
agent_id: str
agent_version: int
graph_id: str
graph_version: int
reviewer_name: str
reviewer_email: str
comments: str

View File

@@ -15,8 +15,8 @@ class TestAgentApprovalData:
"""Test creating valid AgentApprovalData."""
data = AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="Great agent, approved!",
@@ -25,8 +25,8 @@ class TestAgentApprovalData:
)
assert data.agent_name == "Test Agent"
assert data.agent_id == "test-agent-123"
assert data.agent_version == 1
assert data.graph_id == "test-agent-123"
assert data.graph_version == 1
assert data.reviewer_name == "John Doe"
assert data.reviewer_email == "john@example.com"
assert data.comments == "Great agent, approved!"
@@ -40,8 +40,8 @@ class TestAgentApprovalData:
):
AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="Great agent, approved!",
@@ -53,8 +53,8 @@ class TestAgentApprovalData:
"""Test AgentApprovalData with empty comments."""
data = AgentApprovalData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="John Doe",
reviewer_email="john@example.com",
comments="", # Empty comments
@@ -72,8 +72,8 @@ class TestAgentRejectionData:
"""Test creating valid AgentRejectionData."""
data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the security issues before resubmitting.",
@@ -82,8 +82,8 @@ class TestAgentRejectionData:
)
assert data.agent_name == "Test Agent"
assert data.agent_id == "test-agent-123"
assert data.agent_version == 1
assert data.graph_id == "test-agent-123"
assert data.graph_version == 1
assert data.reviewer_name == "Jane Doe"
assert data.reviewer_email == "jane@example.com"
assert data.comments == "Please fix the security issues before resubmitting."
@@ -97,8 +97,8 @@ class TestAgentRejectionData:
):
AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the security issues.",
@@ -111,8 +111,8 @@ class TestAgentRejectionData:
long_comment = "A" * 1000 # Very long comment
data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments=long_comment,
@@ -126,8 +126,8 @@ class TestAgentRejectionData:
"""Test that models can be serialized and deserialized."""
original_data = AgentRejectionData(
agent_name="Test Agent",
agent_id="test-agent-123",
agent_version=1,
graph_id="test-agent-123",
graph_version=1,
reviewer_name="Jane Doe",
reviewer_email="jane@example.com",
comments="Please fix the issues.",
@@ -142,8 +142,8 @@ class TestAgentRejectionData:
restored_data = AgentRejectionData.model_validate(data_dict)
assert restored_data.agent_name == original_data.agent_name
assert restored_data.agent_id == original_data.agent_id
assert restored_data.agent_version == original_data.agent_version
assert restored_data.graph_id == original_data.graph_id
assert restored_data.graph_version == original_data.graph_version
assert restored_data.reviewer_name == original_data.reviewer_name
assert restored_data.reviewer_email == original_data.reviewer_email
assert restored_data.comments == original_data.comments

View File

@@ -244,7 +244,10 @@ def _clean_and_split(text: str) -> list[str]:
def _calculate_points(
agent, categories: list[str], custom: list[str], integrations: list[str]
agent: prisma.models.StoreAgent,
categories: list[str],
custom: list[str],
integrations: list[str],
) -> int:
"""
Calculates the total points for an agent based on the specified criteria.
@@ -397,7 +400,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where={
"is_available": True,
"useForOnboarding": True,
"use_for_onboarding": True,
},
order=[
{"featured": "desc"},
@@ -407,7 +410,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
take=100,
)
# If not enough agents found, relax the useForOnboarding filter
# If not enough agents found, relax the use_for_onboarding filter
if len(storeAgents) < 2:
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
@@ -420,7 +423,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
)
# Calculate points for the first X agents and choose the top 2
agent_points = []
agent_points: list[tuple[prisma.models.StoreAgent, int]] = []
for agent in storeAgents[:POINTS_AGENT_COUNT]:
points = _calculate_points(
agent, categories, custom, user_onboarding.integrations
@@ -430,28 +433,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
agent_points.sort(key=lambda x: x[1], reverse=True)
recommended_agents = [agent for agent, _ in agent_points[:2]]
return [
StoreAgentDetails(
store_listing_version_id=agent.storeListingVersionId,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username,
creator_avatar=agent.creator_avatar,
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
agentGraphVersions=agent.agentGraphVersions,
agentGraphId=agent.agentGraphId,
last_updated=agent.updated_at,
)
for agent in recommended_agents
]
return [StoreAgentDetails.from_db(agent) for agent in recommended_agents]
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes

View File

@@ -380,6 +380,35 @@ async def extract_business_understanding(
return BusinessUnderstandingInput(**cleaned)
async def get_business_understanding_input_from_tally(
email: str,
*,
require_api_key: bool = False,
) -> Optional[BusinessUnderstandingInput]:
settings = Settings()
if not settings.secrets.tally_api_key:
if require_api_key:
raise RuntimeError("Tally API key is not configured")
logger.debug("Tally: no API key configured, skipping")
return None
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return None
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return None
return await extract_business_understanding(formatted)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
@@ -394,30 +423,10 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
)
return
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
understanding_input = await get_business_understanding_input_from_tally(email)
if understanding_input is None:
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")

View File

@@ -166,6 +166,56 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
return merged
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> dict[str, Any]:
merged_data = dict(existing_data)
merged_business: dict[str, Any] = {}
if isinstance(merged_data.get("business"), dict):
merged_business = dict(merged_data["business"])
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
if input_data.user_name is not None:
merged_data["name"] = input_data.user_name
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
merged_business[field] = value
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(merged_business.get(field))
merged_list = _merge_lists(existing_list, value)
merged_business[field] = merged_list
merged_business["version"] = 1
merged_data["business"] = merged_business
return merged_data
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
@@ -245,63 +295,18 @@ async def upsert_business_understanding(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
merged_data = merge_business_understanding_data(existing_data, input_data)
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
},
)

View File

@@ -2,6 +2,7 @@ import logging
from unittest.mock import AsyncMock, patch
import fastapi.responses
import prisma
import pytest
import backend.api.features.library.model
@@ -497,9 +498,24 @@ async def test_store_listing_graph(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
# Ensure the test user has a Profile (required for store submissions)
existing_profile = await prisma.models.Profile.prisma().find_first(
where={"userId": test_user.id}
)
if not existing_profile:
await prisma.models.Profile.prisma().create(
data=prisma.types.ProfileCreateInput(
userId=test_user.id,
name=test_user.name or "Test User",
username=f"test-user-{test_user.id[:8]}",
description="Test user profile",
links=[],
)
)
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
agent_id=test_graph.id,
agent_version=test_graph.version,
graph_id=test_graph.id,
graph_version=test_graph.version,
slug=test_graph.id,
name="Test name",
sub_heading="Test sub heading",
@@ -517,8 +533,8 @@ async def test_store_listing_graph(server: SpinTestServer):
assert False, "Failed to create store listing"
slv_id = (
store_listing.store_listing_version_id
if store_listing.store_listing_version_id is not None
store_listing.listing_version_id
if store_listing.listing_version_id is not None
else None
)

View File

@@ -481,6 +481,22 @@ async def _construct_starting_node_execution_input(
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
input_data.update(node_input_mask)
# Webhook-triggered agents cannot be executed directly without payload data.
# Legitimate webhook triggers provide payload via nodes_input_masks above.
if (
block.block_type
in (
BlockType.WEBHOOK,
BlockType.WEBHOOK_MANUAL,
)
and "payload" not in input_data
):
raise ValueError(
"This agent is triggered by an external event (webhook) "
"and cannot be executed directly. "
"Please use the appropriate trigger to run this agent."
)
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)

View File

@@ -76,7 +76,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
credentials_id=credentials.id,
webhook_type=webhook_type,
resource=resource,
events=None, # Ignore events for this lookup
):
# Re-register with Telegram using the same URL but new allowed_updates
ingress_url = webhook_ingress_url(self.PROVIDER_NAME, existing.id)
@@ -143,10 +142,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
elif "video" in message:
event_type = "message.video"
else:
logger.warning(
"Unknown Telegram webhook payload type; "
f"message.keys() = {message.keys()}"
)
event_type = "message.other"
elif "edited_message" in payload:
event_type = "message.edited_message"

View File

@@ -2,8 +2,8 @@
{#
Template variables:
data.agent_name: the name of the approved agent
data.agent_id: the ID of the agent
data.agent_version: the version of the agent
data.graph_id: the ID of the agent
data.graph_version: the version of the agent
data.reviewer_name: the name of the reviewer who approved it
data.reviewer_email: the email of the reviewer
data.comments: comments from the reviewer
@@ -70,4 +70,4 @@
Thank you for contributing to the AutoGPT ecosystem! 🚀
</p>
{% endblock %}
{% endblock %}

View File

@@ -2,8 +2,8 @@
{#
Template variables:
data.agent_name: the name of the rejected agent
data.agent_id: the ID of the agent
data.agent_version: the version of the agent
data.graph_id: the ID of the agent
data.graph_version: the version of the agent
data.reviewer_name: the name of the reviewer who rejected it
data.reviewer_email: the email of the reviewer
data.comments: comments from the reviewer explaining the rejection
@@ -74,4 +74,4 @@
We're excited to see your improved agent submission! 🚀
</p>
{% endblock %}
{% endblock %}

View File

@@ -64,6 +64,10 @@ class GraphNotInLibraryError(GraphNotAccessibleError):
"""Raised when attempting to execute a graph that is not / no longer in the user's library."""
class PreconditionFailed(Exception):
"""The user must do something else first before trying the current operation"""
class InsufficientBalanceError(ValueError):
user_id: str
message: str

View File

@@ -0,0 +1,32 @@
BEGIN;
-- Drop illogical column StoreListing.agentGraphVersion;
ALTER TABLE "StoreListing" DROP CONSTRAINT "StoreListing_agentGraphId_agentGraphVersion_fkey";
DROP INDEX "StoreListing_agentGraphId_agentGraphVersion_idx";
ALTER TABLE "StoreListing" DROP COLUMN "agentGraphVersion";
-- Add uniqueness constraint to Profile.userId and remove invalid data
--
-- Delete any profiles with null userId (which is invalid and doesn't occur in theory)
DELETE FROM "Profile" WHERE "userId" IS NULL;
--
-- Delete duplicate profiles per userId, keeping the most recently updated one
DELETE FROM "Profile"
WHERE "id" IN (
SELECT "id" FROM (
SELECT "id", ROW_NUMBER() OVER (
PARTITION BY "userId" ORDER BY "updatedAt" DESC, "id" DESC
) AS rn
FROM "Profile"
) ranked
WHERE rn > 1
);
--
-- Add userId uniqueness constraint
ALTER TABLE "Profile" ALTER COLUMN "userId" SET NOT NULL;
CREATE UNIQUE INDEX "Profile_userId_key" ON "Profile"("userId");
-- Add formal relation StoreListing.owningUserId -> Profile.userId
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_owner_Profile_fkey" FOREIGN KEY ("owningUserId") REFERENCES "Profile"("userId") ON DELETE CASCADE ON UPDATE CASCADE;
COMMIT;

View File

@@ -0,0 +1,219 @@
-- Update the StoreSubmission and StoreAgent views with additional fields, clearer field names, and faster joins.
-- Steps:
-- 1. Update `mv_agent_run_counts` to exclude runs by the agent's creator
-- a. Drop dependent views `StoreAgent` and `Creator`
-- b. Update `mv_agent_run_counts` and its index
-- c. Recreate `StoreAgent` view (with updates)
-- d. Restore `Creator` view
-- 2. Update `StoreSubmission` view
-- 3. Update `StoreListingReview` indices to make `StoreSubmission` query more efficient
BEGIN;
-- Drop views that are dependent on mv_agent_run_counts
DROP VIEW IF EXISTS "StoreAgent";
DROP VIEW IF EXISTS "Creator";
-- Update materialized view for agent run counts to exclude runs by the agent's creator
DROP INDEX IF EXISTS "idx_mv_agent_run_counts";
DROP MATERIALIZED VIEW IF EXISTS "mv_agent_run_counts";
CREATE MATERIALIZED VIEW "mv_agent_run_counts" AS
SELECT
run."agentGraphId" AS graph_id,
COUNT(*) AS run_count
FROM "AgentGraphExecution" run
JOIN "AgentGraph" graph ON graph.id = run."agentGraphId"
-- Exclude runs by the agent's creator to avoid inflating run counts
WHERE graph."userId" != run."userId"
GROUP BY run."agentGraphId";
-- Recreate index
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_agent_run_counts" ON "mv_agent_run_counts"("graph_id");
-- Re-populate the materialized view
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
-- Recreate the StoreAgent view with the following changes
-- (compared to 20260115210000_remove_storelistingversion_search):
-- - Narrow to *explicitly active* version (sl.activeVersionId) instead of MAX(version)
-- - Add `recommended_schedule_cron` column
-- - Rename `"storeListingVersionId"` -> `listing_version_id`
-- - Rename `"agentGraphVersions"` -> `graph_versions`
-- - Rename `"agentGraphId"` -> `graph_id`
-- - Rename `"useForOnboarding"` -> `use_for_onboarding`
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH store_agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_graph_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS listing_version_id,
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
slv."agentOutputDemoUrl" AS agent_output_demo,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
cp.username AS creator_username,
cp."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(arc.run_count, 0::bigint) AS runs,
COALESCE(reviews.avg_rating, 0.0)::double precision AS rating,
COALESCE(sav.versions, ARRAY[slv.version::text]) AS versions,
slv."agentGraphId" AS graph_id,
COALESCE(
agv.graph_versions,
ARRAY[slv."agentGraphVersion"::text]
) AS graph_versions,
slv."isAvailable" AS is_available,
COALESCE(sl."useForOnboarding", false) AS use_for_onboarding,
slv."recommendedScheduleCron" AS recommended_schedule_cron
FROM "StoreListing" AS sl
JOIN "StoreListingVersion" AS slv
ON slv."storeListingId" = sl.id
AND slv.id = sl."activeVersionId"
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" AS ag
ON slv."agentGraphId" = ag.id
AND slv."agentGraphVersion" = ag.version
LEFT JOIN "Profile" AS cp
ON sl."owningUserId" = cp."userId"
LEFT JOIN "mv_review_stats" AS reviews
ON sl.id = reviews."storeListingId"
LEFT JOIN "mv_agent_run_counts" AS arc
ON ag.id = arc.graph_id
LEFT JOIN store_agent_versions AS sav
ON sl.id = sav."storeListingId"
LEFT JOIN agent_graph_versions AS agv
ON sl.id = agv."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
-- Restore Creator view as last updated in 20250604130249_optimise_store_agent_and_creator_views,
-- with minor changes:
-- - Ensure top_categories always TEXT[]
-- - Filter out empty ('') categories
CREATE OR REPLACE VIEW "Creator" AS
WITH creator_listings AS (
SELECT
sl."owningUserId",
sl.id AS listing_id,
slv."agentGraphId",
slv.categories,
sr.score,
ar.run_count
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
AND slv."submissionStatus" = 'APPROVED'
LEFT JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv.id
LEFT JOIN "mv_agent_run_counts" ar
ON ar.graph_id = slv."agentGraphId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
),
creator_stats AS (
SELECT
cl."owningUserId",
COUNT(DISTINCT cl.listing_id) AS num_agents,
AVG(COALESCE(cl.score, 0)::numeric) AS agent_rating,
SUM(COALESCE(cl.run_count, 0)) AS agent_runs,
array_agg(DISTINCT cat ORDER BY cat)
FILTER (WHERE cat IS NOT NULL AND cat != '') AS all_categories
FROM creator_listings cl
LEFT JOIN LATERAL unnest(COALESCE(cl.categories, ARRAY[]::text[])) AS cat ON true
GROUP BY cl."owningUserId"
)
SELECT
p.username,
p.name,
p."avatarUrl" AS avatar_url,
p.description,
COALESCE(cs.all_categories, ARRAY[]::text[]) AS top_categories,
p.links,
p."isFeatured" AS is_featured,
COALESCE(cs.num_agents, 0::bigint) AS num_agents,
COALESCE(cs.agent_rating, 0.0) AS agent_rating,
COALESCE(cs.agent_runs, 0::numeric) AS agent_runs
FROM "Profile" p
LEFT JOIN creator_stats cs ON cs."owningUserId" = p."userId";
-- Recreate the StoreSubmission view with updated fields & query strategy:
-- - Uses mv_agent_run_counts instead of full AgentGraphExecution table scan + aggregation
-- - Renamed agent_id, agent_version -> graph_id, graph_version
-- - Renamed store_listing_version_id -> listing_version_id
-- - Renamed date_submitted -> submitted_at
-- - Renamed runs, rating -> run_count, review_avg_rating
-- - Added fields: instructions, agent_output_demo_url, review_count, is_deleted
DROP VIEW IF EXISTS "StoreSubmission";
CREATE OR REPLACE VIEW "StoreSubmission" AS
WITH review_stats AS (
SELECT
"storeListingVersionId" AS version_id, -- more specific than mv_review_stats
avg(score) AS avg_rating,
count(*) AS review_count
FROM "StoreListingReview"
GROUP BY "storeListingVersionId"
)
SELECT
sl.id AS listing_id,
sl."owningUserId" AS user_id,
sl.slug AS slug,
slv.id AS listing_version_id,
slv.version AS listing_version,
slv."agentGraphId" AS graph_id,
slv."agentGraphVersion" AS graph_version,
slv.name AS name,
slv."subHeading" AS sub_heading,
slv.description AS description,
slv.instructions AS instructions,
slv.categories AS categories,
slv."imageUrls" AS image_urls,
slv."videoUrl" AS video_url,
slv."agentOutputDemoUrl" AS agent_output_demo_url,
slv."submittedAt" AS submitted_at,
slv."changesSummary" AS changes_summary,
slv."submissionStatus" AS status,
slv."reviewedAt" AS reviewed_at,
slv."reviewerId" AS reviewer_id,
slv."reviewComments" AS review_comments,
slv."internalComments" AS internal_comments,
slv."isDeleted" AS is_deleted,
COALESCE(run_stats.run_count, 0::bigint) AS run_count,
COALESCE(review_stats.review_count, 0::bigint) AS review_count,
COALESCE(review_stats.avg_rating, 0.0)::double precision AS review_avg_rating
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN review_stats ON review_stats.version_id = slv.id
LEFT JOIN mv_agent_run_counts run_stats ON run_stats.graph_id = slv."agentGraphId"
WHERE sl."isDeleted" = false;
-- Drop unused index on StoreListingReview.reviewByUserId
DROP INDEX IF EXISTS "StoreListingReview_reviewByUserId_idx";
-- Add index on storeListingVersionId to make StoreSubmission query faster
CREATE INDEX "StoreListingReview_storeListingVersionId_idx" ON "StoreListingReview"("storeListingVersionId");
COMMIT;

View File

@@ -0,0 +1,114 @@
-- CreateEnum
CREATE TYPE "InvitedUserStatus" AS ENUM ('INVITED', 'CLAIMED', 'REVOKED');
-- CreateEnum
CREATE TYPE "TallyComputationStatus" AS ENUM ('PENDING', 'RUNNING', 'READY', 'FAILED');
-- CreateTable
CREATE TABLE "InvitedUser" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"email" TEXT NOT NULL,
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
"authUserId" TEXT,
"name" TEXT,
"tallyUnderstanding" JSONB,
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
"tallyComputedAt" TIMESTAMP(3),
"tallyError" TEXT,
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "InvitedUser_email_key" ON "InvitedUser"("email");
-- CreateIndex
CREATE UNIQUE INDEX "InvitedUser_authUserId_key" ON "InvitedUser"("authUserId");
-- CreateIndex
CREATE INDEX "InvitedUser_status_idx" ON "InvitedUser"("status");
-- CreateIndex
CREATE INDEX "InvitedUser_tallyStatus_idx" ON "InvitedUser"("tallyStatus");
-- AddForeignKey
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey"
FOREIGN KEY ("authUserId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
DO $$
DECLARE
allowed_users_schema TEXT;
BEGIN
SELECT table_schema
INTO allowed_users_schema
FROM information_schema.columns
WHERE table_name = 'allowed_users'
AND column_name = 'email'
ORDER BY CASE
WHEN table_schema = 'platform' THEN 0
WHEN table_schema = 'public' THEN 1
ELSE 2
END
LIMIT 1;
IF allowed_users_schema IS NOT NULL THEN
EXECUTE format(
'INSERT INTO platform."InvitedUser" ("id", "email", "status", "tallyStatus", "createdAt", "updatedAt")
SELECT gen_random_uuid()::text,
lower(email),
''INVITED''::platform."InvitedUserStatus",
''PENDING''::platform."TallyComputationStatus",
now(),
now()
FROM %I.allowed_users
WHERE email IS NOT NULL
ON CONFLICT ("email") DO NOTHING',
allowed_users_schema
);
END IF;
END $$;
CREATE OR REPLACE FUNCTION platform.ensure_invited_user_can_register()
RETURNS TRIGGER AS $$
BEGIN
IF NEW.email IS NULL THEN
RAISE EXCEPTION 'The email address "%" is not allowed to register. Please contact support for assistance.', NEW.email
USING ERRCODE = 'P0001';
END IF;
IF lower(split_part(NEW.email, '@', 2)) = 'agpt.co' THEN
RETURN NEW;
END IF;
IF EXISTS (
SELECT 1
FROM platform."InvitedUser" invited_user
WHERE lower(invited_user.email) = lower(NEW.email)
AND invited_user.status = 'INVITED'::platform."InvitedUserStatus"
) THEN
RETURN NEW;
END IF;
RAISE EXCEPTION 'The email address "%" is not allowed to register. Please contact support for assistance.', NEW.email
USING ERRCODE = 'P0001';
END;
$$ LANGUAGE plpgsql SECURITY DEFINER;
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'auth'
AND table_name = 'users'
) THEN
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
DROP TRIGGER IF EXISTS invited_user_signup_gate ON auth.users;
CREATE TRIGGER invited_user_signup_gate
BEFORE INSERT ON auth.users
FOR EACH ROW EXECUTE FUNCTION platform.ensure_invited_user_can_register();
END IF;
END $$;

View File

@@ -65,6 +65,7 @@ model User {
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -73,6 +74,39 @@ model User {
OAuthRefreshTokens OAuthRefreshToken[]
}
enum InvitedUserStatus {
INVITED
CLAIMED
REVOKED
}
enum TallyComputationStatus {
PENDING
RUNNING
READY
FAILED
}
model InvitedUser {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
email String @unique
status InvitedUserStatus @default(INVITED)
authUserId String? @unique
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
name String?
tallyUnderstanding Json?
tallyStatus TallyComputationStatus @default(PENDING)
tallyComputedAt DateTime?
tallyError String?
@@index([status])
@@index([tallyStatus])
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME
@@ -281,7 +315,6 @@ model AgentGraph {
Presets AgentPreset[]
LibraryAgents LibraryAgent[]
StoreListings StoreListing[]
StoreListingVersions StoreListingVersion[]
@@id(name: "graphVersionId", [id, version])
@@ -814,10 +847,8 @@ model Profile {
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
// Only 1 of user or group can be set.
// The user this profile belongs to, if any.
userId String?
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
name String
username String @unique
@@ -830,6 +861,7 @@ model Profile {
isFeatured Boolean @default(false)
LibraryAgents LibraryAgent[]
StoreListings StoreListing[]
@@index([userId])
}
@@ -860,9 +892,9 @@ view Creator {
}
view StoreAgent {
listing_id String @id
storeListingVersionId String
updated_at DateTime
listing_id String @id
listing_version_id String
updated_at DateTime
slug String
agent_name String
@@ -879,10 +911,12 @@ view StoreAgent {
runs Int
rating Float
versions String[]
agentGraphVersions String[]
agentGraphId String
graph_id String
graph_versions String[]
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
use_for_onboarding Boolean @default(false)
recommended_schedule_cron String?
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
@@ -896,41 +930,52 @@ view StoreAgent {
}
view StoreSubmission {
listing_id String @id
user_id String
slug String
name String
sub_heading String
description String
image_urls String[]
date_submitted DateTime
status SubmissionStatus
runs Int
rating Float
agent_id String
agent_version Int
store_listing_version_id String
reviewer_id String?
review_comments String?
internal_comments String?
reviewed_at DateTime?
changes_summary String?
video_url String?
categories String[]
// From StoreListing:
listing_id String
user_id String
slug String
// Index or unique are not applied to views
// From StoreListingVersion:
listing_version_id String @id
listing_version Int
graph_id String
graph_version Int
name String
sub_heading String
description String
instructions String?
categories String[]
image_urls String[]
video_url String?
agent_output_demo_url String?
submitted_at DateTime?
changes_summary String?
status SubmissionStatus
reviewed_at DateTime?
reviewer_id String?
review_comments String?
internal_comments String?
is_deleted Boolean
// Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count Int
review_count Int
review_avg_rating Float
}
// Note: This is actually a MATERIALIZED VIEW in the database
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
view mv_agent_run_counts {
agentGraphId String @unique
run_count Int
graph_id String @unique
run_count Int // excluding runs by the graph's creator
// Pre-aggregated count of AgentGraphExecution records by agentGraphId
// Used by StoreAgent and Creator views for performance optimization
// Unique index created automatically on agentGraphId for fast lookups
// Refresh uses CONCURRENTLY to avoid blocking reads
// Pre-aggregated count of AgentGraphExecution records by agentGraphId.
// Used by StoreAgent, Creator, and StoreSubmission views for performance optimization.
// - Should have a unique index on graph_id for fast lookups
// - Refresh should use CONCURRENTLY to avoid blocking reads
}
// Note: This is actually a MATERIALIZED VIEW in the database
@@ -979,22 +1024,18 @@ model StoreListing {
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
// The agent link here is only so we can do lookup on agentId
agentGraphId String
agentGraphVersion Int
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
agentGraphId String @unique
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
CreatorProfile Profile @relation(fields: [owningUserId], references: [userId], map: "StoreListing_owner_Profile_fkey", onDelete: Cascade)
// Relations
Versions StoreListingVersion[] @relation("ListingVersions")
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
@@unique([agentGraphId])
@@unique([owningUserId, slug])
// Used in the view query
@@index([isDeleted, hasApprovedVersion])
@@index([agentGraphId, agentGraphVersion])
}
model StoreListingVersion {
@@ -1089,16 +1130,16 @@ model UnifiedContentEmbedding {
// Search data
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
searchableText String // Combined text for search and fallback
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
metadata Json @default("{}") // Content-specific metadata
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
@@index([contentType])
@@index([userId])
@@index([contentType, userId])
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
}
model StoreListingReview {
@@ -1115,8 +1156,9 @@ model StoreListingReview {
score Int
comments String?
// Enforce one review per user per listing version
@@unique([storeListingVersionId, reviewByUserId])
@@index([reviewByUserId])
@@index([storeListingVersionId])
}
enum SubmissionStatus {

View File

@@ -23,14 +23,14 @@
"1.0.0",
"1.1.0"
],
"agentGraphVersions": [
"graph_id": "test-graph-id",
"graph_versions": [
"1",
"2"
],
"agentGraphId": "test-graph-id",
"last_updated": "2023-01-01T00:00:00",
"recommended_schedule_cron": null,
"active_version_id": null,
"has_approved_version": false,
"active_version_id": "test-version-id",
"has_approved_version": true,
"changelog": null
}

View File

@@ -1,14 +1,16 @@
{
"name": "Test User",
"username": "creator1",
"name": "Test User",
"description": "Test creator description",
"avatar_url": "avatar.jpg",
"links": [
"link1.com",
"link2.com"
],
"avatar_url": "avatar.jpg",
"agent_rating": 4.8,
"is_featured": true,
"num_agents": 5,
"agent_runs": 1000,
"agent_rating": 4.8,
"top_categories": [
"category1",
"category2"

View File

@@ -1,54 +1,94 @@
{
"creators": [
{
"name": "Creator 0",
"username": "creator0",
"name": "Creator 0",
"description": "Creator 0 description",
"avatar_url": "avatar0.jpg",
"links": [
"user0.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 1",
"username": "creator1",
"name": "Creator 1",
"description": "Creator 1 description",
"avatar_url": "avatar1.jpg",
"links": [
"user1.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 2",
"username": "creator2",
"name": "Creator 2",
"description": "Creator 2 description",
"avatar_url": "avatar2.jpg",
"links": [
"user2.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 3",
"username": "creator3",
"name": "Creator 3",
"description": "Creator 3 description",
"avatar_url": "avatar3.jpg",
"links": [
"user3.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
},
{
"name": "Creator 4",
"username": "creator4",
"name": "Creator 4",
"description": "Creator 4 description",
"avatar_url": "avatar4.jpg",
"links": [
"user4.link.com"
],
"is_featured": false,
"num_agents": 1,
"agent_rating": 4.5,
"agent_runs": 100,
"is_featured": false
"agent_rating": 4.5,
"top_categories": [
"cat1",
"cat2",
"cat3"
]
}
],
"pagination": {

View File

@@ -2,32 +2,33 @@
"submissions": [
{
"listing_id": "test-listing-id",
"agent_id": "test-agent-id",
"agent_version": 1,
"user_id": "test-user-id",
"slug": "test-agent",
"listing_version_id": "test-version-id",
"listing_version": 1,
"graph_id": "test-agent-id",
"graph_version": 1,
"name": "Test Agent",
"sub_heading": "Test agent subheading",
"slug": "test-agent",
"description": "Test agent description",
"instructions": null,
"instructions": "Click the button!",
"categories": [
"test-category"
],
"image_urls": [
"test.jpg"
],
"date_submitted": "2023-01-01T00:00:00",
"video_url": "test.mp4",
"agent_output_demo_url": "demo_video.mp4",
"submitted_at": "2023-01-01T00:00:00",
"changes_summary": "Initial Submission",
"status": "APPROVED",
"runs": 50,
"rating": 4.2,
"store_listing_version_id": null,
"version": null,
"reviewed_at": null,
"reviewer_id": null,
"review_comments": null,
"internal_comments": null,
"reviewed_at": null,
"changes_summary": null,
"video_url": "test.mp4",
"agent_output_demo_url": null,
"categories": [
"test-category"
]
"run_count": 50,
"review_count": 5,
"review_avg_rating": 4.2
}
],
"pagination": {

View File

@@ -128,7 +128,7 @@ class TestDataCreator:
email = "test123@gmail.com"
else:
email = faker.unique.email()
password = "testpassword123" # Standard test password
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
user_id = f"test-user-{i}-{faker.uuid4()}"
# Create user in Supabase Auth (if needed)
@@ -571,8 +571,8 @@ class TestDataCreator:
if test_user and self.agent_graphs:
test_submission_data = {
"user_id": test_user["id"],
"agent_id": self.agent_graphs[0]["id"],
"agent_version": 1,
"graph_id": self.agent_graphs[0]["id"],
"graph_version": 1,
"slug": "test-agent-submission",
"name": "Test Agent Submission",
"sub_heading": "A test agent for frontend testing",
@@ -593,9 +593,9 @@ class TestDataCreator:
print("✅ Created special test store submission for test123@gmail.com")
# ALWAYS approve and feature the test submission
if test_submission.store_listing_version_id:
if test_submission.listing_version_id:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.store_listing_version_id,
store_listing_version_id=test_submission.listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
@@ -605,7 +605,7 @@ class TestDataCreator:
print("✅ Approved test store submission")
await prisma.storelistingversion.update(
where={"id": test_submission.store_listing_version_id},
where={"id": test_submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
@@ -640,8 +640,8 @@ class TestDataCreator:
submission = await create_store_submission(
user_id=user["id"],
agent_id=graph["id"],
agent_version=graph.get("version", 1),
graph_id=graph["id"],
graph_version=graph.get("version", 1),
slug=faker.slug(),
name=graph.get("name", faker.sentence(nb_words=3)),
sub_heading=faker.sentence(),
@@ -654,7 +654,7 @@ class TestDataCreator:
submissions.append(submission.model_dump())
print(f"✅ Created store submission: {submission.name}")
if submission.store_listing_version_id:
if submission.listing_version_id:
# DETERMINISTIC: First N submissions are always approved
# First GUARANTEED_FEATURED_AGENTS of those are always featured
should_approve = (
@@ -667,7 +667,7 @@ class TestDataCreator:
try:
reviewer_id = random.choice(self.users)["id"]
approved_submission = await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
store_listing_version_id=submission.listing_version_id,
is_approved=True,
external_comments="Auto-approved for E2E testing",
internal_comments="Automatically approved by E2E test data script",
@@ -683,9 +683,7 @@ class TestDataCreator:
if should_feature:
try:
await prisma.storelistingversion.update(
where={
"id": submission.store_listing_version_id
},
where={"id": submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
@@ -699,9 +697,7 @@ class TestDataCreator:
elif random.random() < 0.2:
try:
await prisma.storelistingversion.update(
where={
"id": submission.store_listing_version_id
},
where={"id": submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
@@ -721,7 +717,7 @@ class TestDataCreator:
try:
reviewer_id = random.choice(self.users)["id"]
await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
store_listing_version_id=submission.listing_version_id,
is_approved=False,
external_comments="Submission rejected - needs improvements",
internal_comments="Automatically rejected by E2E test data script",

View File

@@ -394,7 +394,6 @@ async def main():
listing = await db.storelisting.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,

View File

@@ -0,0 +1,393 @@
# Beta Invite Pre-Provisioning Design
## Problem
The current signup path is split across three places:
- Supabase creates the auth user and password.
- The backend lazily creates `platform.User` on first authenticated request in [`backend/backend/data/user.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/user.py).
- A Postgres trigger creates `platform.Profile` after `auth.users` insert in [`backend/migrations/20250205100104_add_profile_trigger/migration.sql`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/migrations/20250205100104_add_profile_trigger/migration.sql).
That works for open signup plus a DB-level allowlist, but it does not give the platform a durable object representing:
- a beta invite before the person signs up
- pre-computed onboarding data tied to an email before auth exists
- the distinction between "invited", "claimed", "active", and "revoked"
It also makes first-login setup racey because `User`, `Profile`, `UserOnboarding`, and `CoPilotUnderstanding` are not created from one source of truth.
## Goals
- Allow staff to invite a user before they have a Supabase account.
- Pre-create the platform-side user record and related data before first login.
- Keep Supabase responsible for password entry, email verification, sessions, and OAuth providers.
- Support both password signup and magic-link / OAuth signup for invited users.
- Preserve the existing ability to populate Tally-derived understanding and onboarding defaults.
- Make first login idempotent and safe if the user retries or signs in with a different method.
## Non-goals
- Replacing Supabase Auth.
- Building a full enterprise identity management system.
- Solving general-purpose team/org invites in this change.
## Proposed model
Introduce invite-backed pre-provisioning with email as the pre-auth identity key, then bind the invite to the Supabase `auth.users.id` when the user claims it.
### New tables
#### `BetaInvite`
Represents an invitation sent to a person before they create credentials.
Suggested fields:
- `id`
- `email` unique
- `status` enum: `PENDING`, `CLAIMED`, `EXPIRED`, `REVOKED`
- `inviteTokenHash` nullable
- `invitedByUserId` nullable
- `expiresAt` nullable
- `claimedAt` nullable
- `claimedAuthUserId` nullable unique
- `metadata` jsonb
- `createdAt`
- `updatedAt`
`metadata` should hold operational fields only, for example:
- source: manual import, admin UI, CSV
- cohort: beta wave name
- notes
- original tally email if normalization differs
#### `PreProvisionedUser`
Represents the platform-side user state that exists before Supabase auth is bound.
Suggested fields:
- `id`
- `inviteId` unique
- `email` unique
- `authUserId` nullable unique
- `status` enum: `PENDING_CLAIM`, `ACTIVE`, `MERGE_REQUIRED`, `DISABLED`
- `name` nullable
- `timezone` default `not-set`
- `emailVerified` default `false`
- `metadata` jsonb
- `createdAt`
- `updatedAt`
This is the durable record staff can enrich before login.
#### `PreProvisionedUserSeed`
Stores structured seed data that should become first-class user records on claim.
Suggested fields:
- `id`
- `preProvisionedUserId` unique
- `profile` jsonb nullable
- `onboarding` jsonb nullable
- `businessUnderstanding` jsonb nullable
- `promptContext` jsonb nullable
- `createdAt`
- `updatedAt`
This avoids polluting the main `User.metadata` blob and gives explicit ownership over what is seed data versus ongoing user-authored data.
## Why not pre-create `platform.User` directly?
`platform.User.id` is currently designed to equal the Supabase user id in [`backend/schema.prisma`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/schema.prisma). Pre-creating `User` with a fake UUID would ripple through every FK and complicate the eventual bind. Reusing email as the join key inside `User` is also not safe enough because the rest of the system assumes `User.id` is the canonical identity.
A separate pre-provisioning layer is lower risk:
- it avoids breaking every relation off `User`
- it keeps the existing auth token model intact
- it gives a clean migration path to merge into `User` once a Supabase identity exists
## Claim flow
### 1. Staff creates invite
Admin API or script:
1. Create `BetaInvite`.
2. Create `PreProvisionedUser`.
3. Create `PreProvisionedUserSeed`.
4. Optionally fetch and store Tally-derived understanding immediately by email.
5. Optionally send invite email containing a claim link.
### 2. User opens signup page
Two supported modes:
- direct signup with invited email
- invite-link signup with a short-lived token
Preferred UX:
- `/signup?invite=<token>`
- frontend resolves token to masked email and invite status
- email field is prefilled and locked by default
### 3. User creates Supabase account
Frontend still calls Supabase `signUp`, but include invite context in user metadata:
- `invite_id`
- `invite_email`
This is useful for debugging but should not be the source of truth.
### 4. First authenticated backend call activates invite
Replace the current pure `get_or_create_user` behavior with:
1. Read JWT `sub` and `email`.
2. Look for existing `platform.User` by `id`.
3. If found, return it.
4. Else look for `PreProvisionedUser` by normalized email and `status = PENDING_CLAIM`.
5. In one transaction:
- create `platform.User` with `id = auth sub`
- bind `PreProvisionedUser.authUserId = auth sub`
- mark `PreProvisionedUser.status = ACTIVE`
- mark `BetaInvite.status = CLAIMED`
- create or upsert `Profile`
- create or upsert `UserOnboarding`
- create or upsert `CoPilotUnderstanding`
- create `UserWorkspace` if required for the product experience
6. If no pre-provisioned record exists, either:
- reject access for closed beta, or
- fall back to normal creation if a feature flag allows open signup
This activation logic should live in the backend service, not a Postgres trigger, because it needs to coordinate across multiple platform tables and apply merge rules.
## Data merge rules
When activating a pre-provisioned invite:
- `User`
- source of truth for `id` is Supabase JWT `sub`
- source of truth for `email` is Supabase auth email
- `name` prefers pre-provisioned value, falls back to auth metadata name
- `Profile`
- if seed profile exists, use it
- else create the current generated default
- never overwrite an existing profile if activation is retried
- `UserOnboarding`
- seed values are initial defaults only
- use `upsert` with create-on-missing semantics
- `CoPilotUnderstanding`
- seed from stored Tally extraction if present
- skip Tally backfill on first login if this already exists
- prompt/tally-specific context
- do not jam this into `User.metadata` unless no better typed model exists
- use `PreProvisionedUserSeed.promptContext` now, migrate to typed tables later if product solidifies
## Invite enforcement
The current closed-beta gate appears to rely on a Supabase-side DB error surfaced as "not allowed" in [`frontend/src/app/api/auth/utils.ts`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/api/auth/utils.ts).
That should evolve to:
### Short term
Keep the current Supabase signup restriction, but have it check `BetaInvite`/`PreProvisionedUser` instead of a raw allowlist.
### Medium term
Stop using a generic DB exception as the main product signal and expose a backend endpoint:
- `POST /internal/beta-invites/validate`
- `GET /internal/beta-invites/:token`
Then the frontend can fail earlier with a specific state:
- invited and ready
- invite expired
- already claimed
- not invited
Supabase should still remain the hard gate for credential creation, but the UI should stop depending on opaque trigger failures for normal control flow.
## Trigger changes
The `auth.users` trigger that creates `platform.User` and `platform.Profile` should be removed once activation logic is live.
Reason:
- it cannot see or safely apply pre-provisioned seed data
- it only handles create, not merge/bind
- it duplicates responsibility already present in `get_or_create_user`
Interim state during rollout:
- keep trigger disabled for production once backend activation ships
- keep a backfill script for any auth users created during the transition
## API surface
### Admin APIs
- `POST /admin/beta-invites`
- create invite + pre-provisioned user + seed data
- `POST /admin/beta-invites/bulk`
- CSV import for beta cohorts
- `POST /admin/beta-invites/:id/resend`
- `POST /admin/beta-invites/:id/revoke`
- `GET /admin/beta-invites`
### Public/auth-adjacent APIs
- `GET /beta-invites/lookup?token=...`
- return safe invite info for signup page
- `POST /beta-invites/claim-preview`
- optional: check whether an email is invited before attempting Supabase signup
### Internal service function changes
- Replace `get_or_create_user` with `get_or_activate_user`
- keep a compatibility wrapper if needed to limit churn
## Suggested backend implementation
### New module
- `backend/backend/data/beta_invite.py`
Responsibilities:
- create invite
- resolve invite token
- normalize email
- activate pre-provisioned user
- revoke / expire invite
### Existing module changes
- [`backend/backend/data/user.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/user.py)
- move lazy creation logic into activation-aware flow
- [`backend/backend/api/features/v1.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/api/features/v1.py)
- `/auth/user` should call activation-aware function
- only run background Tally population if no seeded understanding exists
- [`backend/backend/data/tally.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/tally.py)
- add a reusable "seed from email" function for invite creation time
## Suggested frontend implementation
### Signup
- read invite token from URL
- call invite lookup endpoint
- prefill locked email when token is valid
- if invite is invalid, show specific error, not generic waitlist modal
Files likely affected:
- [`frontend/src/app/(platform)/signup/page.tsx`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/(platform)/signup/page.tsx)
- [`frontend/src/app/(platform)/signup/actions.ts`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts)
- [`frontend/src/components/auth/WaitlistErrorContent.tsx`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/components/auth/WaitlistErrorContent.tsx)
### Admin UI
Add a simple internal page for beta invites instead of manually editing allowlists.
Possible location:
- `frontend/src/app/(platform)/admin/users/invites/page.tsx`
## Rollout plan
### Phase 1: schema and service layer
- add `BetaInvite`, `PreProvisionedUser`, `PreProvisionedUserSeed`
- implement activation transaction
- keep existing trigger/allowlist in place
### Phase 2: admin creation path
- add admin API or CLI script
- support single invite and CSV bulk upload
- seed Tally/business understanding during invite creation
### Phase 3: signup UX
- invite token lookup
- better invite state messaging
- preserve existing closed beta modal for non-invited traffic
### Phase 4: remove legacy coupling
- disable `auth.users` profile trigger
- simplify `get_or_create_user`
- migrate allowlist logic to invite tables
## Edge cases
### Existing auth user, no platform user
This already happens today. Activation flow should treat it as:
- if auth email matches a pending pre-provisioned invite, bind and activate
- else create a plain `User` only if open-signup feature flag is enabled
### Existing platform user, invited again
Do not create another `PreProvisionedUser`. Create a second invite only if product explicitly wants re-invites. Otherwise reject as duplicate.
### Email changed after invite
Support admin-side reissue:
- revoke old invite
- create new invite with new email
- move seed data forward
Do not automatically bind across unrelated email addresses.
### OAuth signup
OAuth provider signups still work as long as the resulting Supabase email matches the invited email. If the provider returns a different email, activation should fail with a clear UI message.
### Tally data arrives after invite creation
Allow re-seeding before claim if `CoPilotUnderstanding` has not yet been created on activation.
## Migration notes
### Existing beta users
No immediate migration required for already-active users. This system is mainly for future invited users.
### Existing allowlist entries
Backfill them into `BetaInvite` plus `PreProvisionedUser`, then swap the Supabase gating logic to consult invite tables instead of the legacy allowlist table/trigger.
## Recommendation
Implement the pre-provisioning layer first and keep `platform.User` bound to Supabase `auth.users.id`. That is the lowest-risk design because it respects the existing identity model while giving the business exactly what it needs:
- invite someone before signup
- compute and store Tally/onboarding/prompt defaults before first login
- activate those defaults atomically when the user actually creates credentials
## First implementation slice
The smallest useful slice is:
1. Add `BetaInvite`, `PreProvisionedUser`, and `PreProvisionedUserSeed`.
2. Add a backend admin endpoint to create an invite from email plus optional seed payload.
3. Change `/auth/user` activation logic to bind and materialize seeded `Profile`, `UserOnboarding`, and `CoPilotUnderstanding`.
4. Keep the existing signup UI, but validate invite membership before or during signup.
That delivers the core behavior without needing the full admin UI on day one.

View File

@@ -10,6 +10,7 @@
"cssVariables": false,
"prefix": ""
},
"iconLibrary": "radix",
"aliases": {
"components": "@/components",
"utils": "@/lib/utils"

View File

@@ -1,5 +1,11 @@
import { Sidebar } from "@/components/__legacy__/Sidebar";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import {
Users,
DollarSign,
UserSearch,
FileText,
UserPlus,
} from "lucide-react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
@@ -16,6 +22,11 @@ const sidebarLinkGroups = [
href: "/admin/spending",
icon: <DollarSign className="h-6 w-6" />,
},
{
text: "Beta Invites",
href: "/admin/users",
icon: <UserPlus className="h-6 w-6" />,
},
{
text: "User Impersonation",
href: "/admin/impersonation",

View File

@@ -1,33 +1,39 @@
"use server";
import { revalidatePath } from "next/cache";
import BackendApi from "@/lib/autogpt-server-api";
import {
StoreListingsWithVersionsResponse,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
getV2GetAdminListingsHistory,
postV2ReviewStoreSubmission,
getV2AdminDownloadAgentFile,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { okData } from "@/app/api/helpers";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
export async function approveAgent(formData: FormData) {
const data = {
store_listing_version_id: formData.get("id") as string,
const storeListingVersionId = formData.get("id") as string;
const comments = formData.get("comments") as string;
await postV2ReviewStoreSubmission(storeListingVersionId, {
store_listing_version_id: storeListingVersionId,
is_approved: true,
comments: formData.get("comments") as string,
};
const api = new BackendApi();
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
comments,
});
revalidatePath("/admin/marketplace");
}
export async function rejectAgent(formData: FormData) {
const data = {
store_listing_version_id: formData.get("id") as string,
const storeListingVersionId = formData.get("id") as string;
const comments = formData.get("comments") as string;
const internal_comments =
(formData.get("internal_comments") as string) || undefined;
await postV2ReviewStoreSubmission(storeListingVersionId, {
store_listing_version_id: storeListingVersionId,
is_approved: false,
comments: formData.get("comments") as string,
internal_comments: formData.get("internal_comments") as string,
};
const api = new BackendApi();
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
comments,
internal_comments,
});
revalidatePath("/admin/marketplace");
}
@@ -37,26 +43,18 @@ export async function getAdminListingsWithVersions(
search?: string,
page: number = 1,
pageSize: number = 20,
): Promise<StoreListingsWithVersionsResponse> {
const data: Record<string, any> = {
) {
const response = await getV2GetAdminListingsHistory({
status,
search,
page,
page_size: pageSize,
};
});
if (status) {
data.status = status;
}
if (search) {
data.search = search;
}
const api = new BackendApi();
const response = await api.getAdminListingsWithVersions(data);
return response;
return okData(response);
}
export async function downloadAsAdmin(storeListingVersion: string) {
const api = new BackendApi();
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
return file;
const response = await getV2AdminDownloadAgentFile(storeListingVersion);
return okData(response);
}

View File

@@ -6,10 +6,8 @@ import {
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
import {
StoreSubmission,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { PaginationControls } from "../../../../../components/__legacy__/ui/pagination-controls";
import { getAdminListingsWithVersions } from "@/app/(platform)/admin/marketplace/actions";
import { ExpandableRow } from "./ExpandleRow";
@@ -17,12 +15,14 @@ import { SearchAndFilterAdminMarketplace } from "./SearchFilterForm";
// Helper function to get the latest version by version number
const getLatestVersionByNumber = (
versions: StoreSubmission[],
): StoreSubmission | null => {
versions: StoreSubmissionAdminView[] | undefined,
): StoreSubmissionAdminView | null => {
if (!versions || versions.length === 0) return null;
return versions.reduce(
(latest, current) =>
(current.version ?? 0) > (latest.version ?? 1) ? current : latest,
(current.listing_version ?? 0) > (latest.listing_version ?? 1)
? current
: latest,
versions[0],
);
};
@@ -37,12 +37,14 @@ export async function AdminAgentsDataTable({
initialSearch?: string;
}) {
// Server-side data fetching
const { listings, pagination } = await getAdminListingsWithVersions(
const data = await getAdminListingsWithVersions(
initialStatus,
initialSearch,
initialPage,
10,
);
const listings = data?.listings ?? [];
const pagination = data?.pagination;
return (
<div className="space-y-4">
@@ -92,7 +94,7 @@ export async function AdminAgentsDataTable({
<PaginationControls
currentPage={initialPage}
totalPages={pagination.total_pages}
totalPages={pagination?.total_pages ?? 1}
/>
</div>
);

View File

@@ -13,7 +13,7 @@ import {
} from "@/components/__legacy__/ui/dialog";
import { Label } from "@/components/__legacy__/ui/label";
import { Textarea } from "@/components/__legacy__/ui/textarea";
import type { StoreSubmission } from "@/lib/autogpt-server-api/types";
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
import { useRouter } from "next/navigation";
import {
approveAgent,
@@ -23,7 +23,7 @@ import {
export function ApproveRejectButtons({
version,
}: {
version: StoreSubmission;
version: StoreSubmissionAdminView;
}) {
const router = useRouter();
const [isApproveDialogOpen, setIsApproveDialogOpen] = useState(false);
@@ -95,7 +95,7 @@ export function ApproveRejectButtons({
<input
type="hidden"
name="id"
value={version.store_listing_version_id || ""}
value={version.listing_version_id || ""}
/>
<div className="grid gap-4 py-4">
@@ -142,7 +142,7 @@ export function ApproveRejectButtons({
<input
type="hidden"
name="id"
value={version.store_listing_version_id || ""}
value={version.listing_version_id || ""}
/>
<div className="grid gap-4 py-4">

View File

@@ -12,11 +12,9 @@ import {
import { Badge } from "@/components/__legacy__/ui/badge";
import { ChevronDown, ChevronRight } from "lucide-react";
import { formatDistanceToNow } from "date-fns";
import {
type StoreListingWithVersions,
type StoreSubmission,
SubmissionStatus,
} from "@/lib/autogpt-server-api/types";
import type { StoreListingWithVersionsAdminView } from "@/app/api/__generated__/models/storeListingWithVersionsAdminView";
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { ApproveRejectButtons } from "./ApproveRejectButton";
import { DownloadAgentAdminButton } from "./DownloadAgentButton";
@@ -38,8 +36,8 @@ export function ExpandableRow({
listing,
latestVersion,
}: {
listing: StoreListingWithVersions;
latestVersion: StoreSubmission | null;
listing: StoreListingWithVersionsAdminView;
latestVersion: StoreSubmissionAdminView | null;
}) {
const [expanded, setExpanded] = useState(false);
@@ -69,17 +67,17 @@ export function ExpandableRow({
{latestVersion?.status && getStatusBadge(latestVersion.status)}
</TableCell>
<TableCell onClick={() => setExpanded(!expanded)}>
{latestVersion?.date_submitted
? formatDistanceToNow(new Date(latestVersion.date_submitted), {
{latestVersion?.submitted_at
? formatDistanceToNow(new Date(latestVersion.submitted_at), {
addSuffix: true,
})
: "Unknown"}
</TableCell>
<TableCell className="text-right">
<div className="flex justify-end gap-2">
{latestVersion?.store_listing_version_id && (
{latestVersion?.listing_version_id && (
<DownloadAgentAdminButton
storeListingVersionId={latestVersion.store_listing_version_id}
storeListingVersionId={latestVersion.listing_version_id}
/>
)}
@@ -115,14 +113,17 @@ export function ExpandableRow({
</TableRow>
</TableHeader>
<TableBody>
{listing.versions
.sort((a, b) => (b.version ?? 1) - (a.version ?? 0))
{(listing.versions ?? [])
.sort(
(a, b) =>
(b.listing_version ?? 1) - (a.listing_version ?? 0),
)
.map((version) => (
<TableRow key={version.store_listing_version_id}>
<TableRow key={version.listing_version_id}>
<TableCell>
v{version.version || "?"}
{version.store_listing_version_id ===
listing.active_version_id && (
v{version.listing_version || "?"}
{version.listing_version_id ===
listing.active_listing_version_id && (
<Badge className="ml-2 bg-blue-500">Active</Badge>
)}
</TableCell>
@@ -131,9 +132,9 @@ export function ExpandableRow({
{version.changes_summary || "No summary"}
</TableCell>
<TableCell>
{version.date_submitted
{version.submitted_at
? formatDistanceToNow(
new Date(version.date_submitted),
new Date(version.submitted_at),
{ addSuffix: true },
)
: "Unknown"}
@@ -182,10 +183,10 @@ export function ExpandableRow({
{/* <TableCell>{version.categories.join(", ")}</TableCell> */}
<TableCell className="text-right">
<div className="flex justify-end gap-2">
{version.store_listing_version_id && (
{version.listing_version_id && (
<DownloadAgentAdminButton
storeListingVersionId={
version.store_listing_version_id
version.listing_version_id
}
/>
)}

View File

@@ -12,7 +12,7 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/__legacy__/ui/select";
import { SubmissionStatus } from "@/lib/autogpt-server-api/types";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
export function SearchAndFilterAdminMarketplace({
initialSearch,

View File

@@ -1,11 +1,11 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { Suspense } from "react";
import type { SubmissionStatus } from "@/lib/autogpt-server-api/types";
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { AdminAgentsDataTable } from "./components/AdminAgentsDataTable";
type MarketplaceAdminPageSearchParams = {
page?: string;
status?: string;
status?: SubmissionStatus;
search?: string;
};
@@ -15,7 +15,7 @@ async function AdminMarketplaceDashboard({
searchParams: MarketplaceAdminPageSearchParams;
}) {
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
const status = searchParams.status as SubmissionStatus | undefined;
const status = searchParams.status;
const search = searchParams.search;
return (

View File

@@ -0,0 +1,80 @@
"use client";
import { Card } from "@/components/atoms/Card/Card";
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
import { useAdminUsersPage } from "../../useAdminUsersPage";
export function AdminUsersPage() {
const {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers,
isLoadingInvitedUsers,
isRefreshingInvitedUsers,
isCreatingInvite,
isBulkInviting,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
} = useAdminUsersPage();
return (
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
<div className="flex flex-col gap-2">
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
<p className="max-w-3xl text-sm text-zinc-600">
Pre-provision beta users before they sign up. Invites store the
platform-side record, run Tally understanding extraction, and activate
the real account on the user&apos;s first authenticated request.
</p>
</div>
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
<div className="flex flex-col gap-6">
<Card className="border border-zinc-200 shadow-sm">
<InviteUserForm
email={email}
name={name}
isSubmitting={isCreatingInvite}
onEmailChange={setEmail}
onNameChange={setName}
onSubmit={handleCreateInvite}
/>
</Card>
<Card className="border border-zinc-200 shadow-sm">
<BulkInviteForm
selectedFile={bulkInviteFile}
inputKey={bulkInviteInputKey}
isSubmitting={isBulkInviting}
lastResult={lastBulkInviteResult}
onFileChange={handleBulkInviteFileChange}
onSubmit={handleBulkInviteSubmit}
/>
</Card>
</div>
<Card className="border border-zinc-200 shadow-sm">
<InvitedUsersTable
invitedUsers={invitedUsers}
isLoading={isLoadingInvitedUsers}
isRefreshing={isRefreshingInvitedUsers}
pendingInviteAction={pendingInviteAction}
onRetryTally={handleRetryTally}
onRevoke={handleRevoke}
/>
</Card>
</div>
</div>
);
}

View File

@@ -0,0 +1,131 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import type { FormEvent } from "react";
interface Props {
selectedFile: File | null;
inputKey: number;
isSubmitting: boolean;
lastResult: BulkInvitedUsersResponse | null;
onFileChange: (file: File | null) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
if (status === "CREATED") {
return "success";
}
if (status === "ERROR") {
return "error";
}
return "info";
}
export function BulkInviteForm({
selectedFile,
inputKey,
isSubmitting,
lastResult,
onFileChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
<p className="text-sm text-zinc-600">
Upload a <span className="font-medium text-zinc-800">.txt</span> file
with one email per line, or a{" "}
<span className="font-medium text-zinc-800">.csv</span> with
<span className="font-medium text-zinc-800"> email</span> and optional
<span className="font-medium text-zinc-800"> name</span> columns.
</p>
</div>
<label className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors hover:border-zinc-400 hover:bg-zinc-100">
<span className="font-medium text-zinc-900">
{selectedFile ? selectedFile.name : "Choose invite file"}
</span>
<span>Maximum 500 rows, UTF-8 encoded.</span>
<input
key={inputKey}
type="file"
accept=".txt,.csv,text/plain,text/csv"
disabled={isSubmitting}
className="hidden"
onChange={(event) =>
onFileChange(event.target.files?.item(0) ?? null)
}
/>
</label>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!selectedFile}
className="w-full"
>
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
</Button>
{lastResult ? (
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
<div className="grid grid-cols-3 gap-2 text-center">
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.created_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Created
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.skipped_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Skipped
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.error_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Errors
</div>
</div>
</div>
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
<div className="flex flex-col divide-y divide-zinc-100">
{lastResult.results.map((row) => (
<div
key={`${row.row_number}-${row.email ?? row.message}`}
className="flex items-start gap-3 px-3 py-3"
>
<Badge variant={getStatusVariant(row.status)} size="small">
{row.status}
</Badge>
<div className="flex min-w-0 flex-1 flex-col gap-1">
<span className="text-sm font-medium text-zinc-900">
Row {row.row_number}
{row.email ? ` · ${row.email}` : ""}
</span>
<span className="text-xs text-zinc-500">{row.message}</span>
</div>
</div>
))}
</div>
</div>
</div>
) : null}
</form>
);
}

View File

@@ -0,0 +1,66 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import type { FormEvent } from "react";
interface Props {
email: string;
name: string;
isSubmitting: boolean;
onEmailChange: (value: string) => void;
onNameChange: (value: string) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
export function InviteUserForm({
email,
name,
isSubmitting,
onEmailChange,
onNameChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
<p className="text-sm text-zinc-600">
The invite is stored immediately, then Tally pre-seeding starts in the
background.
</p>
</div>
<Input
id="invite-email"
label="Email"
type="email"
value={email}
placeholder="jane@example.com"
autoComplete="email"
disabled={isSubmitting}
onChange={(event) => onEmailChange(event.target.value)}
/>
<Input
id="invite-name"
label="Name"
type="text"
value={name}
placeholder="Jane Doe"
disabled={isSubmitting}
onChange={(event) => onNameChange(event.target.value)}
/>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!email.trim()}
className="w-full"
>
{isSubmitting ? "Creating invite..." : "Create invite"}
</Button>
</form>
);
}

View File

@@ -0,0 +1,209 @@
"use client";
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
interface Props {
invitedUsers: InvitedUserResponse[];
isLoading: boolean;
isRefreshing: boolean;
pendingInviteAction: string | null;
onRetryTally: (invitedUserId: string) => void;
onRevoke: (invitedUserId: string) => void;
}
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
if (status === "CLAIMED") {
return "success";
}
if (status === "REVOKED") {
return "error";
}
return "info";
}
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
if (status === "READY") {
return "success";
}
if (status === "FAILED") {
return "error";
}
return "info";
}
function formatDate(value: Date | undefined) {
if (!value) {
return "-";
}
return value.toLocaleString();
}
function getTallySummary(invitedUser: InvitedUserResponse) {
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
return invitedUser.tally_error;
}
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
return "Stored and ready for activation";
}
if (invitedUser.tally_status === "READY") {
return "No matching Tally submission found";
}
if (invitedUser.tally_status === "RUNNING") {
return "Extraction in progress";
}
return "Waiting to run";
}
function isActionPending(
pendingInviteAction: string | null,
action: "retry" | "revoke",
invitedUserId: string,
) {
return pendingInviteAction === `${action}:${invitedUserId}`;
}
export function InvitedUsersTable({
invitedUsers,
isLoading,
isRefreshing,
pendingInviteAction,
onRetryTally,
onRevoke,
}: Props) {
return (
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
<p className="text-sm text-zinc-600">
Live invite state, claim status, and Tally pre-seeding progress.
</p>
</div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
</span>
</div>
<div className="overflow-hidden rounded-2xl border border-zinc-200">
<Table>
<TableHeader className="bg-zinc-50">
<TableRow>
<TableHead>Email</TableHead>
<TableHead>Name</TableHead>
<TableHead>Invite</TableHead>
<TableHead>Tally</TableHead>
<TableHead>Claimed User</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{isLoading ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
Loading invited users...
</TableCell>
</TableRow>
) : invitedUsers.length === 0 ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
No invited users yet
</TableCell>
</TableRow>
) : (
invitedUsers.map((invitedUser) => (
<TableRow key={invitedUser.id} className="align-top">
<TableCell className="font-medium text-zinc-900">
{invitedUser.email}
</TableCell>
<TableCell>{invitedUser.name || "-"}</TableCell>
<TableCell>
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
{invitedUser.status}
</Badge>
</TableCell>
<TableCell>
<div className="flex max-w-xs flex-col gap-2">
<Badge
variant={getTallyBadgeVariant(invitedUser.tally_status)}
>
{invitedUser.tally_status}
</Badge>
<span className="text-xs text-zinc-500">
{getTallySummary(invitedUser)}
</span>
<span className="text-xs text-zinc-400">
{formatDate(invitedUser.tally_computed_at ?? undefined)}
</span>
</div>
</TableCell>
<TableCell className="font-mono text-xs text-zinc-500">
{invitedUser.auth_user_id || "-"}
</TableCell>
<TableCell className="text-sm text-zinc-500">
{formatDate(invitedUser.created_at)}
</TableCell>
<TableCell>
<div className="flex justify-end gap-2">
<Button
variant="outline"
size="small"
disabled={invitedUser.status === "REVOKED"}
loading={isActionPending(
pendingInviteAction,
"retry",
invitedUser.id,
)}
onClick={() => onRetryTally(invitedUser.id)}
>
Retry Tally
</Button>
<Button
variant="secondary"
size="small"
disabled={invitedUser.status !== "INVITED"}
loading={isActionPending(
pendingInviteAction,
"revoke",
invitedUser.id,
)}
onClick={() => onRevoke(invitedUser.id)}
>
Revoke
</Button>
</div>
</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
</div>
</div>
);
}

View File

@@ -1,16 +1,11 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import React from "react";
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
function AdminUsers() {
return (
<div>
<h1>Users Dashboard</h1>
{/* Add your admin-only content here */}
</div>
);
return <AdminUsersPage />;
}
export default async function AdminUsersPage() {
export default async function AdminUsersRoute() {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);

View File

@@ -0,0 +1,201 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { okData } from "@/app/api/helpers";
import {
getGetV2ListInvitedUsersQueryKey,
useGetV2ListInvitedUsers,
usePostV2BulkCreateInvitedUsers,
usePostV2CreateInvitedUser,
usePostV2RetryInvitedUserTally,
usePostV2RevokeInvitedUser,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { ApiError } from "@/lib/autogpt-server-api/helpers";
import { useQueryClient } from "@tanstack/react-query";
import { type FormEvent, useState } from "react";
function getErrorMessage(error: unknown) {
if (error instanceof ApiError) {
return error.message;
}
if (error instanceof Error) {
return error.message;
}
return "Something went wrong";
}
export function useAdminUsersPage() {
const queryClient = useQueryClient();
const { toast } = useToast();
const [email, setEmail] = useState("");
const [name, setName] = useState("");
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
const [lastBulkInviteResult, setLastBulkInviteResult] =
useState<BulkInvitedUsersResponse | null>(null);
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
null,
);
const invitedUsersQuery = useGetV2ListInvitedUsers({
query: {
select: okData,
refetchInterval: 5000,
},
});
const createInvitedUserMutation = usePostV2CreateInvitedUser({
mutation: {
onSuccess: async () => {
setEmail("");
setName("");
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invited user created",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
mutation: {
onSuccess: async (response) => {
const result = okData(response) ?? null;
setBulkInviteFile(null);
setBulkInviteInputKey((currentValue) => currentValue + 1);
setLastBulkInviteResult(result);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: result
? `${result.created_count} invites created`
: "Bulk invite upload complete",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Tally pre-seeding restarted",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invite revoked",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
createInvitedUserMutation.mutate({
data: {
email,
name: name.trim() || null,
},
});
}
function handleRetryTally(invitedUserId: string) {
setPendingInviteAction(`retry:${invitedUserId}`);
retryInvitedUserTallyMutation.mutate({ invitedUserId });
}
function handleBulkInviteFileChange(file: File | null) {
setBulkInviteFile(file);
}
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
if (!bulkInviteFile) {
return;
}
bulkCreateInvitedUsersMutation.mutate({
data: {
file: bulkInviteFile,
},
});
}
function handleRevoke(invitedUserId: string) {
setPendingInviteAction(`revoke:${invitedUserId}`);
revokeInvitedUserMutation.mutate({ invitedUserId });
}
return {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
isCreatingInvite: createInvitedUserMutation.isPending,
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
};
}

View File

@@ -151,6 +151,9 @@ export function ChatInput({
onFilesSelected={handleFilesSelected}
disabled={isBusy}
/>
</PromptInputTools>
<div className="flex items-center gap-4">
{showMicButton && (
<RecordingButton
isRecording={isRecording}
@@ -160,13 +163,12 @@ export function ChatInput({
onClick={toggleRecording}
/>
)}
</PromptInputTools>
{isStreaming ? (
<PromptInputSubmit status="streaming" onStop={onStop} />
) : (
<PromptInputSubmit disabled={!canSend} />
)}
{isStreaming ? (
<PromptInputSubmit status="streaming" onStop={onStop} />
) : (
<PromptInputSubmit disabled={!canSend} />
)}
</div>
</PromptInputFooter>
</InputGroup>
</form>

View File

@@ -28,10 +28,9 @@ export function RecordingButton({
disabled={disabled}
onClick={onClick}
className={cn(
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
"border-0 bg-white text-zinc-500 hover:bg-zinc-50 hover:text-zinc-700",
disabled && "opacity-40",
isRecording &&
"animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600",
isRecording && "animate-pulse bg-red-500 text-white hover:bg-red-600",
isTranscribing && "bg-zinc-100 text-zinc-400",
isStreaming && "opacity-40",
)}

View File

@@ -5,15 +5,19 @@ import {
} from "@/components/ai-elements/conversation";
import { Message, MessageContent } from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { FileUIPart, ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
import { parseSpecialMarkers } from "./helpers";
import { AssistantMessageActions } from "./components/AssistantMessageActions";
import { CollapsedToolGroup } from "./components/CollapsedToolGroup";
import { MessageAttachments } from "./components/MessageAttachments";
import { MessagePartRenderer } from "./components/MessagePartRenderer";
import { ReasoningCollapse } from "./components/ReasoningCollapse";
import { ThinkingIndicator } from "./components/ThinkingIndicator";
type MessagePart = UIMessage<unknown, UIDataTypes, UITools>["parts"][number];
interface Props {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
@@ -23,6 +27,132 @@ interface Props {
sessionID?: string | null;
}
function isCompletedToolPart(part: MessagePart): part is ToolUIPart {
return (
part.type.startsWith("tool-") &&
"state" in part &&
(part.state === "output-available" || part.state === "output-error")
);
}
type RenderSegment =
| { kind: "part"; part: MessagePart; index: number }
| { kind: "collapsed-group"; parts: ToolUIPart[] };
// Tool types that have custom renderers and should NOT be collapsed
const CUSTOM_TOOL_TYPES = new Set([
"tool-find_block",
"tool-find_agent",
"tool-find_library_agent",
"tool-search_docs",
"tool-get_doc_page",
"tool-run_block",
"tool-run_mcp_tool",
"tool-run_agent",
"tool-schedule_agent",
"tool-create_agent",
"tool-edit_agent",
"tool-view_agent_output",
"tool-search_feature_requests",
"tool-create_feature_request",
]);
/**
* Groups consecutive completed generic tool parts into collapsed segments.
* Non-generic tools (those with custom renderers) and active/streaming tools
* are left as individual parts.
*/
function buildRenderSegments(
parts: MessagePart[],
baseIndex = 0,
): RenderSegment[] {
const segments: RenderSegment[] = [];
let pendingGroup: Array<{ part: ToolUIPart; index: number }> | null = null;
function flushGroup() {
if (!pendingGroup) return;
if (pendingGroup.length >= 2) {
segments.push({
kind: "collapsed-group",
parts: pendingGroup.map((p) => p.part),
});
} else {
for (const p of pendingGroup) {
segments.push({ kind: "part", part: p.part, index: p.index });
}
}
pendingGroup = null;
}
parts.forEach((part, i) => {
const absoluteIndex = baseIndex + i;
const isGenericCompletedTool =
isCompletedToolPart(part) && !CUSTOM_TOOL_TYPES.has(part.type);
if (isGenericCompletedTool) {
if (!pendingGroup) pendingGroup = [];
pendingGroup.push({ part: part as ToolUIPart, index: absoluteIndex });
} else {
flushGroup();
segments.push({ kind: "part", part, index: absoluteIndex });
}
});
flushGroup();
return segments;
}
/**
* For finalized assistant messages, split parts into "reasoning" (intermediate
* text + tools before the final response) and "response" (final text after the
* last tool). If there are no tools, everything is response.
*/
function splitReasoningAndResponse(parts: MessagePart[]): {
reasoning: MessagePart[];
response: MessagePart[];
} {
const lastToolIndex = parts.findLastIndex((p) => p.type.startsWith("tool-"));
// No tools → everything is response
if (lastToolIndex === -1) {
return { reasoning: [], response: parts };
}
// Check if there's any text after the last tool
const hasResponseAfterTools = parts
.slice(lastToolIndex + 1)
.some((p) => p.type === "text");
if (!hasResponseAfterTools) {
// No final text response → don't collapse anything
return { reasoning: [], response: parts };
}
return {
reasoning: parts.slice(0, lastToolIndex + 1),
response: parts.slice(lastToolIndex + 1),
};
}
function renderSegments(
segments: RenderSegment[],
messageID: string,
): React.ReactNode[] {
return segments.map((seg, segIdx) => {
if (seg.kind === "collapsed-group") {
return <CollapsedToolGroup key={`group-${segIdx}`} parts={seg.parts} />;
}
return (
<MessagePartRenderer
key={`${messageID}-${seg.index}`}
part={seg.part}
messageID={messageID}
partIndex={seg.index}
/>
);
});
}
/** Collect all messages belonging to a turn: the user message + every
* assistant message up to (but not including) the next user message. */
function getTurnMessages(
@@ -119,6 +249,24 @@ export function ChatMessagesContainer({
(p): p is FileUIPart => p.type === "file",
);
// For finalized assistant messages, split into reasoning + response.
// During streaming, show everything normally with tool collapsing.
const isFinalized =
message.role === "assistant" && !isCurrentlyStreaming;
const { reasoning, response } = isFinalized
? splitReasoningAndResponse(message.parts)
: { reasoning: [] as MessagePart[], response: message.parts };
const hasReasoning = reasoning.length > 0;
const responseStartIndex = message.parts.length - response.length;
const responseSegments =
message.role === "assistant"
? buildRenderSegments(response, responseStartIndex)
: null;
const reasoningSegments = hasReasoning
? buildRenderSegments(reasoning, 0)
: null;
return (
<Message from={message.role} key={message.id}>
<MessageContent
@@ -128,14 +276,21 @@ export function ChatMessagesContainer({
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
}
>
{message.parts.map((part, i) => (
<MessagePartRenderer
key={`${message.id}-${i}`}
part={part}
messageID={message.id}
partIndex={i}
/>
))}
{hasReasoning && reasoningSegments && (
<ReasoningCollapse>
{renderSegments(reasoningSegments, message.id)}
</ReasoningCollapse>
)}
{responseSegments
? renderSegments(responseSegments, message.id)
: message.parts.map((part, i) => (
<MessagePartRenderer
key={`${message.id}-${i}`}
part={part}
messageID={message.id}
partIndex={i}
/>
))}
{isLastInTurn && !isCurrentlyStreaming && (
<TurnStatsBar
turnMessages={getTurnMessages(messages, messageIndex)}

View File

@@ -0,0 +1,152 @@
"use client";
import { useId, useState } from "react";
import {
ArrowsClockwiseIcon,
CaretRightIcon,
CheckCircleIcon,
FileIcon,
FilesIcon,
GearIcon,
GlobeIcon,
ListChecksIcon,
MagnifyingGlassIcon,
MonitorIcon,
PencilSimpleIcon,
TerminalIcon,
TrashIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import {
type ToolCategory,
extractToolName,
getAnimationText,
getToolCategory,
} from "../../../tools/GenericTool/helpers";
interface Props {
parts: ToolUIPart[];
}
/** Category icon matching GenericTool's ToolIcon for completed states. */
function EntryIcon({
category,
isError,
}: {
category: ToolCategory;
isError: boolean;
}) {
if (isError) {
return (
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
);
}
const iconClass = "text-green-500";
switch (category) {
case "bash":
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
case "web":
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
case "browser":
return <MonitorIcon size={14} weight="regular" className={iconClass} />;
case "file-read":
case "file-write":
return <FileIcon size={14} weight="regular" className={iconClass} />;
case "file-delete":
return <TrashIcon size={14} weight="regular" className={iconClass} />;
case "file-list":
return <FilesIcon size={14} weight="regular" className={iconClass} />;
case "search":
return (
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
);
case "edit":
return (
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
);
case "todo":
return (
<ListChecksIcon size={14} weight="regular" className={iconClass} />
);
case "compaction":
return (
<ArrowsClockwiseIcon size={14} weight="regular" className={iconClass} />
);
default:
return <GearIcon size={14} weight="regular" className={iconClass} />;
}
}
export function CollapsedToolGroup({ parts }: Props) {
const [expanded, setExpanded] = useState(false);
const panelId = useId();
const errorCount = parts.filter((p) => p.state === "output-error").length;
const label =
errorCount > 0
? `${parts.length} tool calls (${errorCount} failed)`
: `${parts.length} tool calls completed`;
return (
<div className="py-1">
<button
type="button"
onClick={() => setExpanded(!expanded)}
aria-expanded={expanded}
aria-controls={panelId}
className="flex items-center gap-1.5 text-sm text-muted-foreground transition-colors hover:text-foreground"
>
<CaretRightIcon
size={12}
weight="bold"
className={
"transition-transform duration-150 " + (expanded ? "rotate-90" : "")
}
/>
{errorCount > 0 ? (
<WarningDiamondIcon
size={14}
weight="regular"
className="text-red-500"
/>
) : (
<CheckCircleIcon
size={14}
weight="regular"
className="text-green-500"
/>
)}
<span>{label}</span>
</button>
{expanded && (
<div
id={panelId}
className="ml-5 mt-1 space-y-0.5 border-l border-neutral-200 pl-3"
>
{parts.map((part) => {
const toolName = extractToolName(part);
const category = getToolCategory(toolName);
const text = getAnimationText(part, category);
const isError = part.state === "output-error";
return (
<div
key={part.toolCallId}
className={
"flex items-center gap-1.5 text-xs " +
(isError ? "text-red-500" : "text-muted-foreground")
}
>
<EntryIcon category={category} isError={isError} />
<span>{text}</span>
</div>
);
})}
</div>
)}
</div>
);
}

View File

@@ -15,6 +15,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
const [comment, setComment] = useState("");
function handleSubmit() {
if (!comment.trim()) return;
onSubmit(comment);
setComment("");
}
@@ -36,7 +37,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
>
<Dialog.Content>
<div className="mx-auto w-[95%] space-y-4">
<p className="text-sm text-slate-600">
<p className="text-sm text-muted-foreground">
Your feedback helps us improve. Share details below.
</p>
<Textarea
@@ -48,12 +49,18 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
className="resize-none"
/>
<div className="flex items-center justify-between">
<p className="text-xs text-slate-400">{comment.length}/2000</p>
<p className="text-xs text-muted-foreground">
{comment.length}/2000
</p>
<div className="flex gap-2">
<Button variant="outline" size="sm" onClick={handleClose}>
Cancel
</Button>
<Button size="sm" onClick={handleSubmit}>
<Button
size="sm"
onClick={handleSubmit}
disabled={!comment.trim()}
>
Submit feedback
</Button>
</div>

View File

@@ -10,6 +10,7 @@ import {
SearchFeatureRequestsTool,
} from "../../../tools/FeatureRequests/FeatureRequests";
import { FindAgentsTool } from "../../../tools/FindAgents/FindAgents";
import { FolderTool } from "../../../tools/FolderTool/FolderTool";
import { FindBlocksTool } from "../../../tools/FindBlocks/FindBlocks";
import { GenericTool } from "../../../tools/GenericTool/GenericTool";
import { RunAgentTool } from "../../../tools/RunAgent/RunAgent";
@@ -145,6 +146,13 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
return <SearchFeatureRequestsTool key={key} part={part as ToolUIPart} />;
case "tool-create_feature_request":
return <CreateFeatureRequestTool key={key} part={part as ToolUIPart} />;
case "tool-create_folder":
case "tool-list_folders":
case "tool-update_folder":
case "tool-move_folder":
case "tool-delete_folder":
case "tool-move_agents_to_folder":
return <FolderTool key={key} part={part as ToolUIPart} />;
default:
// Render a generic tool indicator for SDK built-in
// tools (Read, Glob, Grep, etc.) or any unrecognized tool

View File

@@ -0,0 +1,27 @@
"use client";
import { LightbulbIcon } from "@phosphor-icons/react";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
interface Props {
children: React.ReactNode;
}
export function ReasoningCollapse({ children }: Props) {
return (
<Dialog title="Reasoning">
<Dialog.Trigger>
<button
type="button"
className="flex items-center gap-1 text-xs text-zinc-500 transition-colors hover:text-zinc-700"
>
<LightbulbIcon size={12} weight="bold" />
<span>Show reasoning</span>
</button>
</Dialog.Trigger>
<Dialog.Content>
<div className="space-y-1">{children}</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -3,6 +3,7 @@ import {
getGetV2ListSessionsQueryKey,
useDeleteV2DeleteSession,
useGetV2ListSessions,
usePatchV2UpdateSessionTitle,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
@@ -17,7 +18,6 @@ import { toast } from "@/components/molecules/Toast/use-toast";
import {
Sidebar,
SidebarContent,
SidebarFooter,
SidebarHeader,
SidebarTrigger,
useSidebar,
@@ -25,8 +25,9 @@ import {
import { cn } from "@/lib/utils";
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query";
import { motion } from "framer-motion";
import { AnimatePresence, motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "../../store";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
@@ -65,6 +66,39 @@ export function ChatSidebar() {
},
});
const [editingSessionId, setEditingSessionId] = useState<string | null>(null);
const [editingTitle, setEditingTitle] = useState("");
const renameInputRef = useRef<HTMLInputElement>(null);
const renameCancelledRef = useRef(false);
const { mutate: renameSession } = usePatchV2UpdateSessionTitle({
mutation: {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
setEditingSessionId(null);
},
onError: (error) => {
toast({
title: "Failed to rename chat",
description:
error instanceof Error ? error.message : "An error occurred",
variant: "destructive",
});
setEditingSessionId(null);
},
},
});
// Auto-focus the rename input when editing starts
useEffect(() => {
if (editingSessionId && renameInputRef.current) {
renameInputRef.current.focus();
renameInputRef.current.select();
}
}, [editingSessionId]);
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
@@ -76,6 +110,26 @@ export function ChatSidebar() {
setSessionId(id);
}
function handleRenameClick(
e: React.MouseEvent,
id: string,
title: string | null | undefined,
) {
e.stopPropagation();
renameCancelledRef.current = false;
setEditingSessionId(id);
setEditingTitle(title || "");
}
function handleRenameSubmit(id: string) {
const trimmed = editingTitle.trim();
if (trimmed) {
renameSession({ sessionId: id, data: { title: trimmed } });
} else {
setEditingSessionId(null);
}
}
function handleDeleteClick(
e: React.MouseEvent,
id: string,
@@ -160,29 +214,42 @@ export function ChatSidebar() {
</motion.div>
</SidebarHeader>
)}
{!isCollapsed && (
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex flex-col gap-3 px-3"
>
<div className="flex items-center justify-between">
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</div>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
className="flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex min-h-[30rem] items-center justify-center py-4">
@@ -203,76 +270,105 @@ export function ChatSidebar() {
: "hover:bg-zinc-50",
)}
>
<button
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
{editingSessionId === session.id ? (
<div className="px-3 py-2.5">
<input
ref={renameInputRef}
type="text"
aria-label="Rename chat"
value={editingTitle}
onChange={(e) => setEditingTitle(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.currentTarget.blur();
} else if (e.key === "Escape") {
renameCancelledRef.current = true;
setEditingSessionId(null);
}
}}
onBlur={() => {
if (renameCancelledRef.current) {
renameCancelledRef.current = false;
return;
}
handleRenameSubmit(session.id);
}}
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
/>
</div>
) : (
<button
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
<AnimatePresence mode="wait" initial={false}>
<motion.span
key={session.title || "untitled"}
initial={{ opacity: 0, y: 4 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -4 }}
transition={{ duration: 0.2 }}
className="block truncate"
>
{session.title || "Untitled chat"}
</motion.span>
</AnimatePresence>
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
onClick={(e) => e.stopPropagation()}
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-4 w-4" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</button>
)}
{editingSessionId !== session.id && (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
onClick={(e) => e.stopPropagation()}
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-4 w-4" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={(e) =>
handleRenameClick(e, session.id, session.title)
}
>
Rename
</DropdownMenuItem>
<DropdownMenuItem
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
)}
</div>
))
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
<DeleteChatDialog

View File

@@ -29,7 +29,6 @@ export function DeleteChatDialog({
}
},
}}
onClose={isDeleting ? undefined : onCancel}
>
<Dialog.Content>
<Text variant="body">

View File

@@ -71,6 +71,17 @@ export function MobileDrawer({
<X width="1rem" height="1rem" />
</Button>
</div>
<div className="mt-2">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
</div>
<div
className={cn(
@@ -120,19 +131,6 @@ export function MobileDrawer({
))
)}
</div>
{currentSessionId && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>

View File

@@ -181,6 +181,14 @@ export function convertChatSessionMessagesToUiMessages(
if (parts.length === 0) return;
// Merge consecutive assistant messages into a single UIMessage
// to avoid split bubbles on page reload.
const prevUI = uiMessages[uiMessages.length - 1];
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
prevUI.parts.push(...parts);
return;
}
uiMessages.push({
id: `${sessionId}-${index}`,
role: msg.role,

View File

@@ -10,7 +10,7 @@ import {
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
export type CreateAgentToolOutput =
| AgentPreviewResponse
@@ -134,7 +134,7 @@ export function ToolIcon({
);
}
if (isStreaming) {
return <OrbitLoader size={24} />;
return <ScaleLoader size={14} />;
}
return <PlusIcon size={14} weight="regular" className="text-neutral-400" />;
}

View File

@@ -9,7 +9,7 @@ import {
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
export type EditAgentToolOutput =
| AgentPreviewResponse
@@ -121,7 +121,7 @@ export function ToolIcon({
);
}
if (isStreaming) {
return <OrbitLoader size={24} />;
return <ScaleLoader size={14} />;
}
return (
<PencilLineIcon size={14} weight="regular" className="text-neutral-400" />

Some files were not shown because too many files have changed in this diff Show More