mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
15 Commits
hotfix/tra
...
swiftyos/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d6338c1fc | ||
|
|
ea69536165 | ||
|
|
eadc68f2a5 | ||
|
|
eca7b5e793 | ||
|
|
c304a4937a | ||
|
|
8cfabcf4fd | ||
|
|
7bf407b66c | ||
|
|
7ead4c040f | ||
|
|
0f813f1bf9 | ||
|
|
aa08063939 | ||
|
|
bde6a4c0df | ||
|
|
d56452898a | ||
|
|
7507240177 | ||
|
|
d7c3f5b8fc | ||
|
|
3e108a813a |
@@ -9,6 +9,21 @@ import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_user(test_user_id: str) -> str:
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
"""Pre-configured snapshot fixture with standard settings."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.blocks
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.external.middleware import require_auth, require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
@@ -278,7 +279,7 @@ async def get_store_agents(
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
@@ -306,13 +307,13 @@ async def get_store_agent(
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
@@ -348,7 +349,7 @@ async def get_store_creators(
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
|
||||
@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||
"""
|
||||
Get store listings with their version history for admins.
|
||||
|
||||
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
Paginated listings with their versions
|
||||
"""
|
||||
try:
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception as e:
|
||||
logger.exception("Error getting admin listings with versions: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving listings with versions"
|
||||
},
|
||||
)
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
) -> store_model.StoreSubmissionAdminView:
|
||||
"""
|
||||
Review a store listing submission.
|
||||
|
||||
@@ -84,31 +73,24 @@ async def review_submission(
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmission with updated review information
|
||||
StoreSubmissionAdminView with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches whenever approval state changes, since store visibility can change
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -805,7 +805,6 @@ async def resume_session_stream(
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
|
||||
@@ -8,7 +8,6 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
The requested LibraryAgent.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during retrieval.
|
||||
"""
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
@@ -398,6 +397,7 @@ async def create_library_agent(
|
||||
hitl_safe_mode: bool = True,
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
create_library_agents_for_sub_graphs: bool = True,
|
||||
folder_id: str | None = None,
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
@@ -414,12 +414,18 @@ async def create_library_agent(
|
||||
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during creation or if image generation fails.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
|
||||
)
|
||||
|
||||
# Authorization: FK only checks existence, not ownership.
|
||||
# Verify the folder belongs to this user to prevent cross-user nesting.
|
||||
if folder_id:
|
||||
await get_folder(folder_id, user_id)
|
||||
|
||||
graph_entries = (
|
||||
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
|
||||
)
|
||||
@@ -432,7 +438,6 @@ async def create_library_agent(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
@@ -448,6 +453,11 @@ async def create_library_agent(
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
|
||||
async def create_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new graph and add it to the user's library."""
|
||||
graph.version = 1
|
||||
@@ -542,6 +553,7 @@ async def create_graph_in_library(
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the store listing or associated agent is not found.
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
|
||||
include_subgraphs=False,
|
||||
)
|
||||
if not graph_model:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
nodes: list[library_model.LibraryFolderTree],
|
||||
visited: set[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Collect all folder IDs from a folder tree."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
ids: list[str] = []
|
||||
for n in nodes:
|
||||
if n.id in visited:
|
||||
continue
|
||||
visited.add(n.id)
|
||||
ids.append(n.id)
|
||||
ids.extend(collect_tree_ids(n.children, visited))
|
||||
return ids
|
||||
|
||||
|
||||
async def get_folder_agent_summaries(
|
||||
user_id: str, folder_id: str
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of agents in a folder (id, name, description)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, folder_id=folder_id, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_root_agent_summaries(
|
||||
user_id: str,
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, include_root_only=True, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_folder_agents_map(
|
||||
user_id: str, folder_ids: list[str]
|
||||
) -> dict[str, list[dict[str, str | None]]]:
|
||||
"""Get agent summaries for multiple folders concurrently."""
|
||||
results = await asyncio.gather(
|
||||
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
|
||||
)
|
||||
return dict(zip(folder_ids, results))
|
||||
|
||||
|
||||
##############################################
|
||||
########### Presets DB Functions #############
|
||||
##############################################
|
||||
|
||||
@@ -4,7 +4,6 @@ import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(db.NotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
@@ -23,7 +21,7 @@ def clear_all_caches():
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
return await store_db.get_store_creator(username=username.lower())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_agents = [
|
||||
prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video=None,
|
||||
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data
|
||||
# Mock data - StoreAgent view already contains the active version data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - should use active version data
|
||||
# Verify results - constructed from the StoreAgent view
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "version123"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
assert result.store_listing_version_id == "version123"
|
||||
assert result.graph_id == "test-graph-id"
|
||||
assert result.runs == 10
|
||||
assert result.rating == 4.5
|
||||
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
# Verify single StoreAgent lookup
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator_details(mocker):
|
||||
async def test_get_store_creator(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
name="Test Creator",
|
||||
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
|
||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_creator_details("creator")
|
||||
result = await db.get_store_creator("creator")
|
||||
|
||||
# Verify results
|
||||
assert result.username == "creator"
|
||||
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
# Mock data
|
||||
now = datetime.now()
|
||||
|
||||
# Mock agent graph (with no pending submissions) and user with profile
|
||||
mock_profile = prisma.models.Profile(
|
||||
id="profile-id",
|
||||
userId="user-id",
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test",
|
||||
isFeatured=False,
|
||||
links=[],
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
mock_user = prisma.models.User(
|
||||
id="user-id",
|
||||
email="test@example.com",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
Profile=[mock_profile],
|
||||
emailVerified=True,
|
||||
metadata="{}", # type: ignore[reportArgumentType]
|
||||
integrations="",
|
||||
maxEmailsPerDay=1,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
version=1,
|
||||
userId="user-id",
|
||||
createdAt=datetime.now(),
|
||||
createdAt=now,
|
||||
isActive=True,
|
||||
StoreListingVersions=[],
|
||||
User=mock_user,
|
||||
)
|
||||
|
||||
mock_listing = prisma.models.StoreListing(
|
||||
# Mock the created StoreListingVersion (returned by create)
|
||||
mock_store_listing_obj = prisma.models.StoreListing(
|
||||
id="listing-id",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
owningUserId="user-id",
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
subHeading="Test heading",
|
||||
imageUrls=["image.jpg"],
|
||||
categories=["test"],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
)
|
||||
],
|
||||
useForOnboarding=False,
|
||||
)
|
||||
mock_version = prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
subHeading="",
|
||||
imageUrls=[],
|
||||
categories=[],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
submittedAt=now,
|
||||
StoreListing=mock_store_listing_obj,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
|
||||
# Mock transaction context manager
|
||||
mock_tx = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.store.db.transaction",
|
||||
return_value=mocker.AsyncMock(
|
||||
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||
__aexit__=mocker.AsyncMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
|
||||
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||
|
||||
# Call function
|
||||
result = await db.create_store_submission(
|
||||
user_id="user-id",
|
||||
agent_id="agent-id",
|
||||
agent_version=1,
|
||||
graph_id="agent-id",
|
||||
graph_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
|
||||
# Verify results
|
||||
assert result.name == "Test Agent"
|
||||
assert result.description == "Test description"
|
||||
assert result.store_listing_version_id == "version-id"
|
||||
assert result.listing_version_id == "version-id"
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
mock_slv.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
|
||||
description="Test description",
|
||||
links=["link1"],
|
||||
avatar_url="avatar.jpg",
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
# Call function
|
||||
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||
category="AI'; DELETE FROM StoreAgent; --",
|
||||
featured=True,
|
||||
sorted_by="rating",
|
||||
sorted_by=db.StoreAgentsSortOptions.RATING,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
@@ -57,12 +57,6 @@ class StoreError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class AgentNotFoundError(NotFoundError):
|
||||
"""Raised when an agent is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreatorNotFoundError(NotFoundError):
|
||||
"""Raised when a creator is not found"""
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
@@ -582,7 +582,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId", uce.embedding
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
sa."agentGraphId",
|
||||
sa.graph_id,
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
@@ -627,9 +627,9 @@ async def hybrid_search(
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
ON c."storeListingVersionId" = sa.listing_version_id
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId"
|
||||
ON sa.listing_version_id = uce."contentId"
|
||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_vals AS (
|
||||
@@ -665,7 +665,7 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
"agentGraphId",
|
||||
graph_id,
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import datetime
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Self
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import prisma.models
|
||||
|
||||
|
||||
class ChangelogEntry(pydantic.BaseModel):
|
||||
version: str
|
||||
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
|
||||
date: datetime.datetime
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
class MyUnpublishedAgent(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_name: str
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyAgent]
|
||||
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyUnpublishedAgent]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
|
||||
rating: float
|
||||
agent_graph_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
|
||||
return cls(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
agent_graph_id=agent.graph_id,
|
||||
)
|
||||
|
||||
|
||||
class StoreAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[StoreAgent]
|
||||
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
agentGraphVersions: list[str]
|
||||
agentGraphId: str
|
||||
graph_id: str
|
||||
graph_versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
active_version_id: str
|
||||
has_approved_version: bool
|
||||
|
||||
# Optional changelog data when include_changelog=True
|
||||
changelog: list[ChangelogEntry] | None = None
|
||||
|
||||
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
is_featured: bool
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[Creator]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class CreatorDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
top_categories: list[str]
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
|
||||
return cls(
|
||||
store_listing_version_id=agent.listing_version_id,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
graph_id=agent.graph_id,
|
||||
graph_versions=agent.graph_versions,
|
||||
last_updated=agent.updated_at,
|
||||
recommended_schedule_cron=agent.recommended_schedule_cron,
|
||||
active_version_id=agent.listing_version_id,
|
||||
has_approved_version=True, # StoreAgent view only has approved agents
|
||||
)
|
||||
|
||||
|
||||
class Profile(pydantic.BaseModel):
|
||||
name: str
|
||||
"""Marketplace user profile (only attributes that the user can update)"""
|
||||
|
||||
username: str
|
||||
name: str
|
||||
description: str
|
||||
avatar_url: str | None
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class ProfileDetails(Profile):
|
||||
"""Marketplace user profile (including read-only fields)"""
|
||||
|
||||
is_featured: bool
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
|
||||
return cls(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatarUrl,
|
||||
description=profile.description,
|
||||
links=profile.links,
|
||||
is_featured=profile.isFeatured,
|
||||
)
|
||||
|
||||
|
||||
class CreatorDetails(ProfileDetails):
|
||||
"""Marketplace creator profile details, including aggregated stats"""
|
||||
|
||||
num_agents: int
|
||||
agent_runs: int
|
||||
agent_rating: float
|
||||
top_categories: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
|
||||
return cls(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
avatar_url=creator.avatar_url,
|
||||
description=creator.description,
|
||||
links=creator.links,
|
||||
is_featured=creator.is_featured,
|
||||
num_agents=creator.num_agents,
|
||||
agent_runs=creator.agent_runs,
|
||||
agent_rating=creator.agent_rating,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[CreatorDetails]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
# From StoreListing:
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
user_id: str
|
||||
slug: str
|
||||
|
||||
# From StoreListingVersion:
|
||||
listing_version_id: str
|
||||
listing_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
instructions: str | None
|
||||
categories: list[str]
|
||||
image_urls: list[str]
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
version: int | None = None # Actual version number from the database
|
||||
video_url: str | None
|
||||
agent_output_demo_url: str | None
|
||||
|
||||
submitted_at: datetime.datetime | None
|
||||
changes_summary: str | None
|
||||
status: prisma.enums.SubmissionStatus
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
# Additional fields for editing
|
||||
video_url: str | None = None
|
||||
agent_output_demo_url: str | None = None
|
||||
categories: list[str] = []
|
||||
# Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count: int = 0
|
||||
review_count: int = 0
|
||||
review_avg_rating: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
"""Construct from the StoreSubmission Prisma view."""
|
||||
return cls(
|
||||
listing_id=_sub.listing_id,
|
||||
user_id=_sub.user_id,
|
||||
slug=_sub.slug,
|
||||
listing_version_id=_sub.listing_version_id,
|
||||
listing_version=_sub.listing_version,
|
||||
graph_id=_sub.graph_id,
|
||||
graph_version=_sub.graph_version,
|
||||
name=_sub.name,
|
||||
sub_heading=_sub.sub_heading,
|
||||
description=_sub.description,
|
||||
instructions=_sub.instructions,
|
||||
categories=_sub.categories,
|
||||
image_urls=_sub.image_urls,
|
||||
video_url=_sub.video_url,
|
||||
agent_output_demo_url=_sub.agent_output_demo_url,
|
||||
submitted_at=_sub.submitted_at,
|
||||
changes_summary=_sub.changes_summary,
|
||||
status=_sub.status,
|
||||
reviewed_at=_sub.reviewed_at,
|
||||
reviewer_id=_sub.reviewer_id,
|
||||
review_comments=_sub.review_comments,
|
||||
run_count=_sub.run_count,
|
||||
review_count=_sub.review_count,
|
||||
review_avg_rating=_sub.review_avg_rating,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
"""
|
||||
Construct from the StoreListingVersion Prisma model (with StoreListing included)
|
||||
"""
|
||||
if not (_l := _lv.StoreListing):
|
||||
raise ValueError("StoreListingVersion must have included StoreListing")
|
||||
|
||||
return cls(
|
||||
listing_id=_l.id,
|
||||
user_id=_l.owningUserId,
|
||||
slug=_l.slug,
|
||||
listing_version_id=_lv.id,
|
||||
listing_version=_lv.version,
|
||||
graph_id=_lv.agentGraphId,
|
||||
graph_version=_lv.agentGraphVersion,
|
||||
name=_lv.name,
|
||||
sub_heading=_lv.subHeading,
|
||||
description=_lv.description,
|
||||
instructions=_lv.instructions,
|
||||
categories=_lv.categories,
|
||||
image_urls=_lv.imageUrls,
|
||||
video_url=_lv.videoUrl,
|
||||
agent_output_demo_url=_lv.agentOutputDemoUrl,
|
||||
submitted_at=_lv.submittedAt,
|
||||
changes_summary=_lv.changesSummary,
|
||||
status=_lv.submissionStatus,
|
||||
reviewed_at=_lv.reviewedAt,
|
||||
reviewer_id=_lv.reviewerId,
|
||||
review_comments=_lv.reviewComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreListingWithVersions(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
slug: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmission | None = None
|
||||
versions: list[StoreSubmission] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersions]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
graph_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Graph ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
graph_version: int = pydantic.Field(
|
||||
..., gt=0, description="Graph version must be greater than 0"
|
||||
)
|
||||
slug: str
|
||||
name: str
|
||||
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str | None = None
|
||||
class StoreSubmissionAdminView(StoreSubmission):
|
||||
internal_comments: str | None # Private admin notes
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_db(_sub).model_dump(),
|
||||
internal_comments=_sub.internal_comments,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_listing_version(_lv).model_dump(),
|
||||
internal_comments=_lv.internalComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
graph_id: str
|
||||
slug: str
|
||||
active_listing_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmissionAdminView | None = None
|
||||
versions: list[StoreSubmissionAdminView] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersionsAdminView]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreReview(pydantic.BaseModel):
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from . import model as store_model
|
||||
|
||||
|
||||
def test_pagination():
|
||||
pagination = store_model.Pagination(
|
||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
||||
)
|
||||
assert pagination.total_items == 100
|
||||
assert pagination.total_pages == 5
|
||||
assert pagination.current_page == 2
|
||||
assert pagination.page_size == 20
|
||||
|
||||
|
||||
def test_store_agent():
|
||||
agent = store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
assert agent.slug == "test-agent"
|
||||
assert agent.agent_name == "Test Agent"
|
||||
assert agent.runs == 50
|
||||
assert agent.rating == 4.5
|
||||
assert agent.agent_graph_id == "test-graph-id"
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
response = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.agents) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_agent_details():
|
||||
details = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
agent_output_demo="demo.mp4",
|
||||
agent_image=["image1.jpg", "image2.jpg"],
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=datetime.datetime.now(),
|
||||
)
|
||||
assert details.slug == "test-agent"
|
||||
assert len(details.agent_image) == 2
|
||||
assert len(details.categories) == 2
|
||||
assert len(details.versions) == 2
|
||||
|
||||
|
||||
def test_creator():
|
||||
creator = store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
assert creator.name == "Test Creator"
|
||||
assert creator.num_agents == 5
|
||||
|
||||
|
||||
def test_creators_response():
|
||||
response = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.creators) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_creator_details():
|
||||
details = store_model.CreatorDetails(
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
top_categories=["cat1", "cat2"],
|
||||
)
|
||||
assert details.name == "Test Creator"
|
||||
assert len(details.links) == 2
|
||||
assert details.agent_rating == 4.8
|
||||
assert len(details.top_categories) == 2
|
||||
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
assert submission.name == "Test Agent"
|
||||
assert len(submission.image_urls) == 2
|
||||
assert submission.status == prisma.enums.SubmissionStatus.PENDING
|
||||
|
||||
|
||||
def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.submissions) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_submission_request():
|
||||
request = store_model.StoreSubmissionRequest(
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
sub_heading="Test subheading",
|
||||
video_url="video.mp4",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
)
|
||||
assert request.agent_id == "agent123"
|
||||
assert request.agent_version == 1
|
||||
assert len(request.image_urls) == 2
|
||||
assert len(request.categories) == 2
|
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
import urllib.parse
|
||||
from typing import Literal
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
from fastapi import Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.data.graph
|
||||
import backend.util.json
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.ProfileDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_profile(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Get the profile details for the authenticated user."""
|
||||
profile = await store_db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
raise NotFoundError("User does not have a profile yet")
|
||||
return profile
|
||||
|
||||
|
||||
@@ -57,98 +51,17 @@ async def get_profile(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.CreatorDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def update_or_create_profile(
|
||||
profile: store_model.Profile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Update the store profile for the authenticated user.
|
||||
|
||||
Args:
|
||||
profile (Profile): The updated profile details
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
CreatorDetails: The updated profile
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Update the store profile for the authenticated user."""
|
||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||
return updated_profile
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
##############################################
|
||||
############### Search Endpoints #############
|
||||
##############################################
|
||||
@@ -158,60 +71,30 @@ async def get_agents(
|
||||
"/search",
|
||||
summary="Unified search across all content types",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.UnifiedSearchResponse,
|
||||
)
|
||||
async def unified_search(
|
||||
query: str,
|
||||
content_types: list[str] | None = fastapi.Query(
|
||||
content_types: list[prisma.enums.ContentType] | None = Query(
|
||||
default=None,
|
||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
||||
description="Content types to search. If not specified, searches all.",
|
||||
),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user_id: str | None = fastapi.Security(
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
user_id: str | None = Security(
|
||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||
),
|
||||
):
|
||||
) -> store_model.UnifiedSearchResponse:
|
||||
"""
|
||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
||||
Search across all content types (marketplace agents, blocks, documentation)
|
||||
using hybrid search.
|
||||
|
||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of results per page (default 20)
|
||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||
|
||||
Returns:
|
||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
# Convert string content types to enum
|
||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||
if content_types:
|
||||
try:
|
||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||
)
|
||||
|
||||
# Perform unified hybrid search
|
||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=content_type_enums,
|
||||
content_types=content_types,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@@ -245,22 +128,69 @@ async def unified_search(
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured agents"
|
||||
),
|
||||
creator: str | None = Query(
|
||||
default=None, description="Filter agents by creator username"
|
||||
),
|
||||
category: str | None = Query(default=None, description="Filter agents by category"),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
||||
default=None,
|
||||
description="Property to sort results by. Ignored if search_query is provided.",
|
||||
),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get a paginated list of agents from the marketplace,
|
||||
with optional filtering and sorting.
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_agent(
|
||||
async def get_agent_by_name(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
include_changelog: bool = fastapi.Query(default=False),
|
||||
):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
include_changelog: bool = Query(default=False),
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get details of a marketplace agent"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
@@ -270,76 +200,82 @@ async def get_agent(
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreReview,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_review(
|
||||
async def post_user_review_for_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
review: store_model.StoreReviewCreate,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a review for a store agent.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
agent_name: Name/slug of the agent
|
||||
review: Review details including score and optional comments
|
||||
user_id: ID of authenticated user creating the review
|
||||
|
||||
Returns:
|
||||
The created review
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreReview:
|
||||
"""Post a user review on a marketplace agent listing"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
|
||||
created_review = await store_db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
|
||||
return created_review
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_agent_by_listing_version(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""Get outline of graph belonging to a specific marketplace listing version"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph/download",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############# Creator Endpoints #############
|
||||
##############################################
|
||||
@@ -349,37 +285,19 @@ async def create_review(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
This is needed for:
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
- featured: bool - to limit the list to just featured agents
|
||||
- search_query: str - vector search based on the creators profile description.
|
||||
- sorted_by: [agent_rating, agent_runs] -
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured creators"
|
||||
),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""List or search marketplace creators"""
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
@@ -391,18 +309,12 @@ async def get_creators(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
"/creators/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_creator(
|
||||
username: str,
|
||||
):
|
||||
"""
|
||||
Get the details of a creator.
|
||||
- Creator Details Page
|
||||
"""
|
||||
async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||
"""Get details on a marketplace creator"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
@@ -414,20 +326,17 @@ async def get_creator(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/myagents",
|
||||
"/my-unpublished-agents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.MyAgentsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_my_agents(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
async def get_my_unpublished_agents(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.MyUnpublishedAgentsResponse:
|
||||
"""List the authenticated user's unpublished agents"""
|
||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
|
||||
@@ -436,28 +345,17 @@ async def get_my_agents(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=bool,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def delete_submission(
|
||||
submission_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
submission_id (str): ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> bool:
|
||||
"""Delete a marketplace listing submission"""
|
||||
result = await store_db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -465,37 +363,14 @@ async def delete_submission(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmissionsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_submissions(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreListingsResponse: Paginated list of store submissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreSubmissionsResponse:
|
||||
"""List the authenticated user's marketplace listing submissions"""
|
||||
listings = await store_db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
@@ -508,30 +383,17 @@ async def get_submissions(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_submission(
|
||||
submission_request: store_model.StoreSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a new store listing submission.
|
||||
|
||||
Args:
|
||||
submission_request (StoreSubmissionRequest): The submission details
|
||||
user_id (str): ID of the authenticated user submitting the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Submit a new marketplace listing for review"""
|
||||
result = await store_db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
graph_id=submission_request.graph_id,
|
||||
graph_version=submission_request.graph_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
@@ -544,7 +406,6 @@ async def create_submission(
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -552,28 +413,14 @@ async def create_submission(
|
||||
"/submissions/{store_listing_version_id}",
|
||||
summary="Edit store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def edit_submission(
|
||||
store_listing_version_id: str,
|
||||
submission_request: store_model.StoreSubmissionEditRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to edit
|
||||
submission_request (StoreSubmissionRequest): The updated submission details
|
||||
user_id (str): ID of the authenticated user editing the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The updated store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Update a pending marketplace listing submission"""
|
||||
result = await store_db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
@@ -588,7 +435,6 @@ async def edit_submission(
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -596,115 +442,61 @@ async def edit_submission(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def upload_submission_media(
|
||||
file: fastapi.UploadFile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Upload media (images/videos) for a store listing submission.
|
||||
|
||||
Args:
|
||||
file (UploadFile): The media file to upload
|
||||
user_id (str): ID of the authenticated user uploading the media
|
||||
|
||||
Returns:
|
||||
str: URL of the uploaded media file
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> str:
|
||||
"""Upload media for a marketplace listing submission"""
|
||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||
return media_url
|
||||
|
||||
|
||||
class ImageURLResponse(BaseModel):
|
||||
image_url: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def generate_image(
|
||||
agent_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> fastapi.responses.Response:
|
||||
graph_id: str,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> ImageURLResponse:
|
||||
"""
|
||||
Generate an image for a store listing submission.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent to generate an image for
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
Generate an image for a marketplace listing submission based on the properties
|
||||
of a given graph.
|
||||
"""
|
||||
agent = await backend.data.graph.get_graph(
|
||||
graph_id=agent_id, version=None, user_id=user_id
|
||||
graph = await backend.data.graph.get_graph(
|
||||
graph_id=graph_id, version=None, user_id=user_id
|
||||
)
|
||||
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
filename = f"agent_{graph_id}.jpeg"
|
||||
|
||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
logger.info(f"Using existing image for agent graph {graph_id}")
|
||||
return ImageURLResponse(image_url=existing_url)
|
||||
# Generate agent image as JPEG
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return ImageURLResponse(image_url=image_url)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
|
||||
from . import model as store_model
|
||||
from . import routes as store_routes
|
||||
|
||||
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
|
||||
mock_db_call.assert_called_once_with(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by="runs",
|
||||
sorted_by=StoreAgentsSortOptions.RUNS,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
@@ -380,9 +382,11 @@ def test_get_agent_details(
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=["1.0.0", "1.1.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_versions=["1", "2"],
|
||||
graph_id="test-graph-id",
|
||||
last_updated=FIXED_NOW,
|
||||
active_version_id="test-version-id",
|
||||
has_approved_version=True,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||
mock_db_call.return_value = mocked_value
|
||||
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
|
||||
) -> None:
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
store_model.CreatorDetails(
|
||||
name=f"Creator {i}",
|
||||
username=f"creator{i}",
|
||||
description=f"Creator {i} description",
|
||||
avatar_url=f"avatar{i}.jpg",
|
||||
num_agents=1,
|
||||
agent_rating=4.5,
|
||||
agent_runs=100,
|
||||
description=f"Creator {i} description",
|
||||
links=[f"user{i}.link.com"],
|
||||
is_featured=False,
|
||||
num_agents=1,
|
||||
agent_runs=100,
|
||||
agent_rating=4.5,
|
||||
top_categories=["cat1", "cat2", "cat3"],
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
@@ -496,19 +502,19 @@ def test_get_creator_details(
|
||||
mocked_value = store_model.CreatorDetails(
|
||||
name="Test User",
|
||||
username="creator1",
|
||||
avatar_url="avatar.jpg",
|
||||
description="Test creator description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
is_featured=True,
|
||||
num_agents=5,
|
||||
agent_runs=1000,
|
||||
agent_rating=4.8,
|
||||
top_categories=["category1", "category2"],
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.store.db.get_store_creator_details"
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creator/creator1")
|
||||
response = client.get("/creators/creator1")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = store_model.CreatorDetails.model_validate(response.json())
|
||||
@@ -528,19 +534,26 @@ def test_get_submissions_success(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
date_submitted=FIXED_NOW,
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
runs=50,
|
||||
rating=4.2,
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
sub_heading="Test agent subheading",
|
||||
user_id="test-user-id",
|
||||
slug="test-agent",
|
||||
video_url="test.mp4",
|
||||
listing_version_id="test-version-id",
|
||||
listing_version=1,
|
||||
graph_id="test-agent-id",
|
||||
graph_version=1,
|
||||
name="Test Agent",
|
||||
sub_heading="Test agent subheading",
|
||||
description="Test agent description",
|
||||
instructions="Click the button!",
|
||||
categories=["test-category"],
|
||||
image_urls=["test.jpg"],
|
||||
video_url="test.mp4",
|
||||
agent_output_demo_url="demo_video.mp4",
|
||||
submitted_at=FIXED_NOW,
|
||||
changes_summary="Initial Submission",
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
run_count=50,
|
||||
review_count=5,
|
||||
review_avg_rating=4.2,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from .db import StoreAgentsSortOptions
|
||||
from .model import StoreAgent, StoreAgentsResponse
|
||||
|
||||
|
||||
@@ -215,7 +216,7 @@ class TestCacheDeletion:
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -227,7 +228,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -239,7 +240,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
|
||||
@@ -29,6 +29,7 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
from backend.api.model import (
|
||||
BusinessUnderstandingPromptsResponse,
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
@@ -54,6 +55,7 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
@@ -158,6 +160,22 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/understanding/prompts",
|
||||
summary="Get business understanding prompts",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=BusinessUnderstandingPromptsResponse,
|
||||
)
|
||||
async def get_business_understanding_prompts_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> BusinessUnderstandingPromptsResponse:
|
||||
understanding = await understanding_db().get_business_understanding(user_id)
|
||||
return BusinessUnderstandingPromptsResponse(
|
||||
prompts=understanding.prompts if understanding else []
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/email",
|
||||
summary="Update user email",
|
||||
@@ -449,7 +467,6 @@ async def execute_graph_block(
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
file: UploadFile = File(...),
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 24,
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
@@ -512,7 +529,6 @@ async def upload_file(
|
||||
storage_path = await cloud_storage.store_file(
|
||||
content=content,
|
||||
filename=file_name,
|
||||
provider=provider,
|
||||
expiration_hours=expiration_hours,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -89,6 +89,48 @@ def test_update_user_email_route(
|
||||
)
|
||||
|
||||
|
||||
def test_get_business_understanding_prompts_route(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_understanding_db = Mock()
|
||||
mock_understanding_db.get_business_understanding = AsyncMock(
|
||||
return_value=Mock(prompts=["Prompt one", "Prompt two", "Prompt three"])
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.understanding_db",
|
||||
return_value=mock_understanding_db,
|
||||
)
|
||||
|
||||
response = client.get("/auth/user/understanding/prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Prompt one", "Prompt two", "Prompt three"]}
|
||||
mock_understanding_db.get_business_understanding.assert_awaited_once_with(
|
||||
test_user_id
|
||||
)
|
||||
|
||||
|
||||
def test_get_business_understanding_prompts_route_returns_empty_list(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_understanding_db = Mock()
|
||||
mock_understanding_db.get_business_understanding = AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.understanding_db",
|
||||
return_value=mock_understanding_db,
|
||||
)
|
||||
|
||||
response = client.get("/auth/user/understanding/prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
mock_understanding_db.get_business_understanding.assert_awaited_once_with(
|
||||
test_user_id
|
||||
)
|
||||
|
||||
|
||||
# Blocks endpoints tests
|
||||
def test_get_graph_blocks(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
@@ -515,7 +557,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
result = await upload_file(
|
||||
file=upload_file_mock,
|
||||
user_id=test_user_id,
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
)
|
||||
|
||||
@@ -533,7 +574,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
mock_handler.store_file.assert_called_once_with(
|
||||
content=file_content,
|
||||
filename="test.txt",
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
@@ -85,6 +85,10 @@ class UpdateTimezoneRequest(pydantic.BaseModel):
|
||||
timezone: TimeZoneName
|
||||
|
||||
|
||||
class BusinessUnderstandingPromptsResponse(pydantic.BaseModel):
|
||||
prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class NotificationPayload(pydantic.BaseModel):
|
||||
type: str
|
||||
event: str
|
||||
|
||||
@@ -55,6 +55,7 @@ from backend.util.exceptions import (
|
||||
MissingConfigError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
@@ -275,6 +276,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
sort_by: StoreAgentsSortOptions = SchemaField(
|
||||
description="How to sort the results", default=StoreAgentsSortOptions.RATING
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
|
||||
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
query="test",
|
||||
category="productivity",
|
||||
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
|
||||
limit=10,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
|
||||
@@ -22,6 +22,7 @@ from backend.copilot.model import (
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -176,14 +177,17 @@ async def stream_chat_completion_baseline(
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
|
||||
191
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
191
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Centralized prompt building logic for CoPilot.
|
||||
|
||||
This module contains all prompt construction functions and constants,
|
||||
handling the distinction between:
|
||||
- SDK mode vs Baseline mode (tool documentation needs)
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
working_dir: str,
|
||||
sandbox_type: str,
|
||||
storage_system_1_name: str,
|
||||
storage_system_1_characteristics: list[str],
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
Template function handles all formatting (bullets, indentation, markdown).
|
||||
Callers provide clean data as lists of strings.
|
||||
|
||||
Args:
|
||||
working_dir: Working directory path
|
||||
sandbox_type: Description of bash_exec sandbox
|
||||
storage_system_1_name: Name of primary storage (ephemeral or cloud)
|
||||
storage_system_1_characteristics: List of characteristic descriptions
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
|
||||
|
||||
return f"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
storage_system_1_name="Ephemeral working directory",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files here are **lost between turns** — do NOT rely on them persisting",
|
||||
"Use for temporary work: running scripts, processing data, etc.",
|
||||
],
|
||||
file_move_name_1_to_2="Ephemeral → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Ephemeral",
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
storage_system_1_name="Cloud sandbox",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by all file tools AND `bash_exec` — same filesystem",
|
||||
"Full Linux environment with internet access",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files **persist across turns** within the current session",
|
||||
"Lost when the session expires (12 h inactivity)",
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
SDK mode doesn't need it since Claude gets tool schemas automatically.
|
||||
|
||||
This generates a complete list of available tools with their descriptions,
|
||||
ensuring the documentation stays in sync with the actual tool implementations.
|
||||
All workflow guidance is now embedded in individual tool descriptions.
|
||||
|
||||
Only documents tools that are available in the current environment
|
||||
(checked via tool.is_available property).
|
||||
"""
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
docs += f"- **`{name}`**: {desc}\n"
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
"""Get the supplement for SDK mode (Claude Agent SDK).
|
||||
|
||||
SDK mode does NOT include tool documentation because Claude automatically
|
||||
receives tool schemas from the SDK. Only includes technical notes about
|
||||
storage systems and execution environment.
|
||||
|
||||
Args:
|
||||
use_e2b: Whether E2B cloud sandbox is being used
|
||||
cwd: Current working directory (only used in local_storage mode)
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
if use_e2b:
|
||||
return _get_cloud_sandbox_supplement()
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
direct API doesn't automatically provide tool schemas to Claude.
|
||||
Also includes shared technical notes (but NOT SDK-specific environment details).
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
@@ -44,6 +44,7 @@ from ..model import (
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..prompting import get_sdk_supplement
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -146,140 +147,6 @@ _SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
|
||||
|
||||
# Appended to the system prompt to inform the agent about available tools.
|
||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
||||
# which has kernel-level network isolation (unshare --net).
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Long-running tools
|
||||
Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Large tool outputs
|
||||
When a tool output exceeds the display limit, it is automatically saved to
|
||||
the persistent workspace. The truncated output includes a
|
||||
`<tool-output-truncated>` tag with the workspace path. Use
|
||||
`read_workspace_file(path="...", offset=N, length=50000)` to retrieve
|
||||
additional sections.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
_LOCAL_TOOL_SUPPLEMENT = (
|
||||
"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a network-isolated sandbox.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{cwd}`
|
||||
- All SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec` operate inside this
|
||||
directory. This is the ONLY writable path — do not attempt to read or write
|
||||
anywhere else on the filesystem.
|
||||
- Use relative paths or absolute paths under `{cwd}` for all file operations.
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **Ephemeral working directory** (`{cwd}`):
|
||||
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
|
||||
- Files here are **lost between turns** — do NOT rely on them persisting
|
||||
- Use for temporary work: running scripts, processing data, etc.
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across turns and sessions**
|
||||
- Use `write_workspace_file` to save important files (code, outputs, configs)
|
||||
- Use `read_workspace_file` to retrieve previously saved files
|
||||
- Use `list_workspace_files` to see what files you've saved before
|
||||
- Call `list_workspace_files(include_all_sessions=True)` to see files from
|
||||
all sessions
|
||||
|
||||
### Moving files between ephemeral and persistent storage
|
||||
- **Ephemeral → Persistent**: Use `write_workspace_file` with either:
|
||||
- `content` param (plain text) — for text files
|
||||
- `source_path` param — to copy any file directly from the ephemeral dir
|
||||
- **Persistent → Ephemeral**: Use `read_workspace_file` with `save_to_path`
|
||||
param to download a workspace file to the ephemeral dir for processing
|
||||
|
||||
### File persistence workflow
|
||||
When you create or modify important files (code, configs, outputs), you MUST:
|
||||
1. Save them using `write_workspace_file` so they persist
|
||||
2. At the start of a new turn, call `list_workspace_files` to see what files
|
||||
are available from previous turns
|
||||
"""
|
||||
+ _SHARED_TOOL_NOTES
|
||||
)
|
||||
|
||||
|
||||
_E2B_TOOL_SUPPLEMENT = (
|
||||
"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a cloud sandbox with full internet access.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `/home/user` (cloud sandbox)
|
||||
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
|
||||
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
|
||||
- Files created by `bash_exec` are immediately visible to `read_file` and
|
||||
vice-versa — they share one filesystem.
|
||||
- Use relative paths (resolved from `/home/user`) or absolute paths.
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **Cloud sandbox** (`/home/user`):
|
||||
- Shared by all file tools AND `bash_exec` — same filesystem
|
||||
- Files **persist across turns** within the current session
|
||||
- Full Linux environment with internet access
|
||||
- Lost when the session expires (12 h inactivity)
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
- Use `write_workspace_file` to save important files permanently
|
||||
- Use `read_workspace_file` to retrieve previously saved files
|
||||
- Use `list_workspace_files` to see what files you've saved before
|
||||
- Call `list_workspace_files(include_all_sessions=True)` to see files from
|
||||
all sessions
|
||||
|
||||
### Moving files between sandbox and persistent storage
|
||||
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
|
||||
to copy from the sandbox to permanent storage
|
||||
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
|
||||
to download into the sandbox for processing
|
||||
|
||||
### File persistence workflow
|
||||
Important files that must survive beyond this session should be saved with
|
||||
`write_workspace_file`. Sandbox files persist across turns but are lost
|
||||
when the session expires.
|
||||
"""
|
||||
+ _SHARED_TOOL_NOTES
|
||||
)
|
||||
|
||||
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
|
||||
|
||||
@@ -444,6 +311,7 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"""Convert SDK content blocks to transcript format.
|
||||
|
||||
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
|
||||
Unknown block types are logged and skipped.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for block in blocks or []:
|
||||
@@ -459,13 +327,14 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
}
|
||||
)
|
||||
elif isinstance(block, ToolResultBlock):
|
||||
result.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.tool_use_id,
|
||||
"content": block.content,
|
||||
}
|
||||
)
|
||||
tool_result_entry: dict[str, Any] = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.tool_use_id,
|
||||
"content": block.content,
|
||||
}
|
||||
if block.is_error:
|
||||
tool_result_entry["is_error"] = True
|
||||
result.append(tool_result_entry)
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
result.append(
|
||||
{
|
||||
@@ -474,6 +343,11 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"signature": block.signature,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Unknown content block type: {type(block).__name__}. "
|
||||
f"This may indicate a new SDK version with additional block types."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -959,10 +833,9 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
use_e2b = e2b_sandbox is not None
|
||||
system_prompt = base_system_prompt + (
|
||||
_E2B_TOOL_SUPPLEMENT
|
||||
if use_e2b
|
||||
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
|
||||
# Append appropriate supplement (Claude gets tool schemas automatically)
|
||||
system_prompt = base_system_prompt + get_sdk_supplement(
|
||||
use_e2b=use_e2b, cwd=sdk_cwd
|
||||
)
|
||||
|
||||
# Process transcript download result
|
||||
@@ -980,7 +853,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
if is_valid:
|
||||
# Load previous FULL context into builder
|
||||
transcript_builder.load_previous(dl.content)
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
@@ -1127,18 +1000,18 @@ async def stream_chat_completion_sdk(
|
||||
json.dumps(user_msg) + "\n"
|
||||
)
|
||||
# Capture user message in transcript (multimodal)
|
||||
transcript_builder.add_user_message(content=content_blocks)
|
||||
transcript_builder.append_user(content=content_blocks)
|
||||
else:
|
||||
await client.query(query_message, session_id=session_id)
|
||||
# Capture user message in transcript (text only)
|
||||
transcript_builder.add_user_message(content=query_message)
|
||||
# Capture actual user message in transcript (not the engineered query)
|
||||
# query_message may include context wrappers, but transcript needs raw input
|
||||
transcript_builder.append_user(content=current_message)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
ended_with_stream_error = False
|
||||
|
||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
|
||||
@@ -1209,15 +1082,6 @@ async def stream_chat_completion_sdk(
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Capture AssistantMessage in transcript
|
||||
if isinstance(sdk_msg, AssistantMessage):
|
||||
content_blocks = _format_sdk_content_blocks(sdk_msg.content)
|
||||
model_name = getattr(sdk_msg, "model", "")
|
||||
transcript_builder.add_assistant_message(
|
||||
content_blocks=content_blocks,
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
@@ -1348,22 +1212,37 @@ async def stream_chat_completion_sdk(
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
content = (
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else json.dumps(response.output, ensure_ascii=False)
|
||||
)
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
content=content,
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=response.toolCallId,
|
||||
content=content,
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
# Append assistant entry AFTER convert_message so that
|
||||
# any stashed tool results from the previous turn are
|
||||
# recorded first, preserving the required API order:
|
||||
# assistant(tool_use) → tool_result → assistant(text).
|
||||
if isinstance(sdk_msg, AssistantMessage):
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
|
||||
model=sdk_msg.model,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task/generator was cancelled (e.g. client disconnect,
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
@@ -1412,6 +1291,15 @@ async def stream_chat_completion_sdk(
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
)
|
||||
if isinstance(response, StreamToolOutputAvailable):
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=response.toolCallId,
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else json.dumps(response.output, ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
yield response
|
||||
|
||||
# If the stream ended without a ResultMessage, the SDK
|
||||
@@ -1554,8 +1442,10 @@ async def stream_chat_completion_sdk(
|
||||
transcript_builder.entry_count,
|
||||
len(transcript_content),
|
||||
)
|
||||
# Create task first so we have a reference if timeout occurs
|
||||
upload_task = asyncio.create_task(
|
||||
# Shield upload from cancellation - let it complete even if
|
||||
# the finally block is interrupted. No timeout to avoid race
|
||||
# conditions where backgrounded uploads overwrite newer transcripts.
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
@@ -1564,19 +1454,6 @@ async def stream_chat_completion_sdk(
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
try:
|
||||
async with asyncio.timeout(30):
|
||||
await asyncio.shield(upload_task)
|
||||
except TimeoutError:
|
||||
# Timeout fired but shield keeps upload running - track the
|
||||
# SAME task to prevent garbage collection (no double upload)
|
||||
logger.warning(
|
||||
"%s Transcript upload exceeded 30s timeout, "
|
||||
"continuing in background",
|
||||
log_prefix,
|
||||
)
|
||||
_background_tasks.add(upload_task)
|
||||
upload_task.add_done_callback(_background_tasks.discard)
|
||||
except Exception as upload_err:
|
||||
logger.error(
|
||||
"%s Transcript upload failed in finally: %s",
|
||||
|
||||
@@ -145,3 +145,103 @@ class TestPrepareFileAttachments:
|
||||
|
||||
assert "Read tool" not in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
|
||||
|
||||
class TestPromptSupplement:
|
||||
"""Tests for centralized prompt supplement generation."""
|
||||
|
||||
def test_sdk_supplement_excludes_tool_docs(self):
|
||||
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
|
||||
from backend.copilot.prompting import get_sdk_supplement
|
||||
|
||||
# Test both local and E2B modes
|
||||
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
|
||||
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
|
||||
|
||||
# Should NOT have tool list section
|
||||
assert "## AVAILABLE TOOLS" not in local_supplement
|
||||
assert "## AVAILABLE TOOLS" not in e2b_supplement
|
||||
|
||||
# Should still have technical notes
|
||||
assert "## Tool notes" in local_supplement
|
||||
assert "## Tool notes" in e2b_supplement
|
||||
|
||||
def test_baseline_supplement_includes_tool_docs(self):
|
||||
"""Baseline mode MUST include tool documentation (direct API needs it)."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
supplement = get_baseline_supplement()
|
||||
|
||||
# MUST have tool list section
|
||||
assert "## AVAILABLE TOOLS" in supplement
|
||||
|
||||
# Should NOT have environment-specific notes (SDK-only)
|
||||
assert "## Tool notes" not in supplement
|
||||
|
||||
def test_baseline_supplement_includes_key_tools(self):
|
||||
"""Baseline supplement should document all essential tools."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Core agent workflow tools (always available)
|
||||
assert "`create_agent`" in docs
|
||||
assert "`run_agent`" in docs
|
||||
assert "`find_library_agent`" in docs
|
||||
assert "`edit_agent`" in docs
|
||||
|
||||
# MCP integration (always available)
|
||||
assert "`run_mcp_tool`" in docs
|
||||
|
||||
# Folder management (always available)
|
||||
assert "`create_folder`" in docs
|
||||
|
||||
# Browser tools only if available (Playwright may not be installed in CI)
|
||||
if (
|
||||
TOOL_REGISTRY.get("browser_navigate")
|
||||
and TOOL_REGISTRY["browser_navigate"].is_available
|
||||
):
|
||||
assert "`browser_navigate`" in docs
|
||||
|
||||
def test_baseline_supplement_includes_workflows(self):
|
||||
"""Baseline supplement should include workflow guidance in tool descriptions."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Workflows are now in individual tool descriptions (not separate sections)
|
||||
# Check that key workflow concepts appear in tool descriptions
|
||||
assert "suggested_goal" in docs or "clarifying_questions" in docs
|
||||
assert "run_mcp_tool" in docs
|
||||
|
||||
def test_baseline_supplement_completeness(self):
|
||||
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Verify each available registered tool is documented
|
||||
# (matches _generate_tool_documentation which filters by is_available)
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
assert (
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Count occurrences of each available tool in the entire supplement
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
@@ -10,13 +10,14 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -68,17 +69,14 @@ def strip_progress_entries(content: str) -> str:
|
||||
# Parse entries, keeping the original line alongside the parsed dict.
|
||||
parsed: list[tuple[str, dict | None]] = []
|
||||
for line in lines:
|
||||
try:
|
||||
parsed.append((line, json.loads(line)))
|
||||
except json.JSONDecodeError:
|
||||
parsed.append((line, None))
|
||||
parsed.append((line, json.loads(line, fallback=None)))
|
||||
|
||||
# First pass: identify stripped UUIDs and build parent map.
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
|
||||
for _line, entry in parsed:
|
||||
if entry is None:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
@@ -91,7 +89,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
# Preserve original line when no reparenting is required.
|
||||
reparented: set[str] = set()
|
||||
for _line, entry in parsed:
|
||||
if entry is None:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
@@ -105,7 +103,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
|
||||
result_lines: list[str] = []
|
||||
for line, entry in parsed:
|
||||
if entry is None:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
@@ -225,12 +223,11 @@ def validate_transcript(content: str | None) -> bool:
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
if entry.get("type") == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
return False
|
||||
if entry.get("type") == "assistant":
|
||||
has_assistant = True
|
||||
|
||||
return has_assistant
|
||||
|
||||
@@ -310,10 +307,8 @@ async def upload_transcript(
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
try:
|
||||
entry_types.append(json.loads(line).get("type", "?"))
|
||||
except json.JSONDecodeError:
|
||||
entry_types.append("INVALID_JSON")
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -396,10 +391,10 @@ async def download_transcript(
|
||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
||||
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"))
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
|
||||
@@ -11,13 +11,16 @@ Flow:
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import STRIPPABLE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -41,7 +44,16 @@ class TranscriptBuilder:
|
||||
self._entries: list[TranscriptEntry] = []
|
||||
self._last_uuid: str | None = None
|
||||
|
||||
def load_previous(self, content: str) -> None:
|
||||
def _last_is_assistant(self) -> bool:
|
||||
return bool(self._entries) and self._entries[-1].type == "assistant"
|
||||
|
||||
def _last_message_id(self) -> str:
|
||||
"""Return the message.id of the last entry, or '' if none."""
|
||||
if self._entries:
|
||||
return self._entries[-1].message.get("id", "")
|
||||
return ""
|
||||
|
||||
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
This loads the FULL previous context. As new messages come in,
|
||||
@@ -51,19 +63,25 @@ class TranscriptBuilder:
|
||||
if not content or not content.strip():
|
||||
return
|
||||
|
||||
for line in content.strip().split("\n"):
|
||||
lines = content.strip().split("\n")
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse transcript line: %s", line[:100])
|
||||
data = json.loads(line, fallback=None)
|
||||
if data is None:
|
||||
logger.warning(
|
||||
"%s Failed to parse transcript line %d/%d",
|
||||
log_prefix,
|
||||
line_num,
|
||||
len(lines),
|
||||
)
|
||||
continue
|
||||
|
||||
# Only load conversation messages (user/assistant)
|
||||
# Skip metadata entries
|
||||
if data.get("type") not in ("user", "assistant"):
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
@@ -76,15 +94,14 @@ class TranscriptBuilder:
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
logger.info(
|
||||
"Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
log_prefix,
|
||||
len(self._entries),
|
||||
self._last_uuid[:12] if self._last_uuid else None,
|
||||
)
|
||||
|
||||
def add_user_message(
|
||||
self, content: str | list[dict], uuid: str | None = None
|
||||
) -> None:
|
||||
"""Add user message to the complete context."""
|
||||
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
|
||||
"""Append a user entry."""
|
||||
msg_uuid = uuid or str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
@@ -97,10 +114,34 @@ class TranscriptBuilder:
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def add_assistant_message(
|
||||
self, content_blocks: list[dict], model: str = ""
|
||||
def append_tool_result(self, tool_use_id: str, content: str) -> None:
|
||||
"""Append a tool result as a user entry (one per tool call)."""
|
||||
self.append_user(
|
||||
content=[
|
||||
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
|
||||
]
|
||||
)
|
||||
|
||||
def append_assistant(
|
||||
self,
|
||||
content_blocks: list[dict],
|
||||
model: str = "",
|
||||
stop_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Add assistant message to the complete context."""
|
||||
"""Append an assistant entry.
|
||||
|
||||
Consecutive assistant entries automatically share the same message ID
|
||||
so the CLI can merge them (thinking → text → tool_use) into a single
|
||||
API message on ``--resume``. A new ID is assigned whenever an
|
||||
assistant entry follows a non-assistant entry (user message or tool
|
||||
result), because that marks the start of a new API response.
|
||||
"""
|
||||
message_id = (
|
||||
self._last_message_id()
|
||||
if self._last_is_assistant()
|
||||
else f"msg_sdk_{uuid4().hex[:24]}"
|
||||
)
|
||||
|
||||
msg_uuid = str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
@@ -111,7 +152,11 @@ class TranscriptBuilder:
|
||||
message={
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"content": content_blocks,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -120,6 +165,9 @@ class TranscriptBuilder:
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
Consecutive assistant entries are kept separate to match the
|
||||
native CLI format — the SDK merges them internally on resume.
|
||||
|
||||
Returns the FULL conversation state (all entries), not incremental.
|
||||
This output REPLACES any previous transcript.
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
strip_progress_entries,
|
||||
@@ -256,10 +257,9 @@ class TestStripProgressEntries:
|
||||
|
||||
def test_preserves_original_line_formatting(self):
|
||||
"""Non-reparented entries keep their original JSON formatting."""
|
||||
# Use pretty-printed JSON with spaces (as the CLI produces)
|
||||
original_line = json.dumps(USER_MSG) # default formatting with spaces
|
||||
compact_line = json.dumps(USER_MSG, separators=(",", ":"))
|
||||
assert original_line != compact_line # precondition
|
||||
# orjson produces compact JSON - test that we preserve the exact input
|
||||
# when no reparenting is needed (no re-serialization)
|
||||
original_line = json.dumps(USER_MSG)
|
||||
|
||||
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
|
||||
result = strip_progress_entries(content)
|
||||
|
||||
@@ -18,7 +18,7 @@ from langfuse.openai import (
|
||||
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import format_understanding_for_prompt
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from .config import ChatConfig
|
||||
@@ -34,8 +34,9 @@ client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
langfuse = get_client()
|
||||
|
||||
# Default system prompt used when Langfuse is not configured
|
||||
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
|
||||
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||
# Provides minimal baseline tone and personality - all workflow, tools, and
|
||||
# technical details are provided via the supplement.
|
||||
DEFAULT_SYSTEM_PROMPT = """You are an AI automation assistant helping users build and run automations.
|
||||
|
||||
Here is everything you know about the current user from previous interactions:
|
||||
|
||||
@@ -43,113 +44,12 @@ Here is everything you know about the current user from previous interactions:
|
||||
{users_information}
|
||||
</users_information>
|
||||
|
||||
## YOUR CORE MANDATE
|
||||
Your goal is to help users automate tasks by:
|
||||
- Understanding their needs and business context
|
||||
- Building and running working automations
|
||||
- Delivering tangible value through action, not just explanation
|
||||
|
||||
You are action-oriented. Your success is measured by:
|
||||
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||
- **Time Saved**: Focus on tangible efficiency gains
|
||||
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||
|
||||
## YOUR WORKFLOW
|
||||
|
||||
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||
|
||||
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||
|
||||
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||
|
||||
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||
|
||||
4. **Discover or Create Agents**:
|
||||
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||
- Search the marketplace with `find_agent` for pre-built automations
|
||||
- Find reusable components with `find_block`
|
||||
- **For live integrations** (read a GitHub repo, query a database, post to Slack, etc.) consider `run_mcp_tool` — it connects directly to external services without building a full agent
|
||||
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||
- Modify existing library agents with `edit_agent`
|
||||
- **When `create_agent` returns `suggested_goal`**: Present the suggestion to the user and ask "Would you like me to proceed with this refined goal?" If they accept, call `create_agent` again with the suggested goal.
|
||||
- **When `create_agent` returns `clarifying_questions`**: After the user answers, call `create_agent` again with the original description AND the answers in the `context` parameter.
|
||||
|
||||
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||
|
||||
6. **Show Results**: Display outputs using `agent_output`.
|
||||
|
||||
## AVAILABLE TOOLS
|
||||
|
||||
**Understanding & Discovery:**
|
||||
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
|
||||
- `search_docs`: Search platform documentation for specific technical information
|
||||
- `get_doc_page`: Retrieve full text of a specific documentation page
|
||||
|
||||
**Agent Discovery:**
|
||||
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
|
||||
- `find_agent`: Search the marketplace for pre-built automations
|
||||
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
- `create_agent`: Create a new automation agent
|
||||
- `edit_agent`: Modify an agent in the user's library
|
||||
|
||||
**Execution & Output:**
|
||||
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
|
||||
- `run_block`: Test or run a specific block independently
|
||||
- `agent_output`: View results from previous agent runs
|
||||
|
||||
**MCP (Model Context Protocol) Servers:**
|
||||
- `run_mcp_tool`: Connect to any MCP server to discover and run its tools
|
||||
|
||||
**Two-step flow:**
|
||||
1. `run_mcp_tool(server_url)` → returns a list of available tools. Each tool has `name`, `description`, and `input_schema` (JSON Schema). Read `input_schema.properties` to understand what arguments are needed.
|
||||
2. `run_mcp_tool(server_url, tool_name, tool_arguments)` → executes the tool. Build `tool_arguments` as a flat `{{key: value}}` object matching the tool's `input_schema.properties`.
|
||||
|
||||
**Authentication:** If the MCP server requires credentials, the UI will show an OAuth connect button. Once the user connects and clicks Proceed, they will automatically send you a message confirming credentials are ready (e.g. "I've connected the MCP server credentials. Please retry run_mcp_tool..."). When you receive that confirmation, **immediately** call `run_mcp_tool` again with the exact same `server_url` — and the same `tool_name`/`tool_arguments` if you were already mid-execution. Do not ask the user what to do next; just retry.
|
||||
|
||||
**Finding server URLs (fastest → slowest):**
|
||||
1. **Known hosted servers** — use directly, no lookup:
|
||||
- Notion: `https://mcp.notion.com/mcp`
|
||||
- Linear: `https://mcp.linear.app/mcp`
|
||||
- Stripe: `https://mcp.stripe.com`
|
||||
- Intercom: `https://mcp.intercom.com/mcp`
|
||||
- Cloudflare: `https://mcp.cloudflare.com/mcp`
|
||||
- Atlassian (Jira/Confluence): `https://mcp.atlassian.com/mcp`
|
||||
2. **`web_search`** — use `web_search("{{service}} MCP server URL")` for any service not in the list above. This is the fastest way to find unlisted servers.
|
||||
3. **Registry API** — `web_fetch("https://registry.modelcontextprotocol.io/v0.1/servers?search={{query}}&limit=10")` to browse what's available. Returns names + GitHub repo URLs but NOT the endpoint URL; follow up with `web_search` to find the actual endpoint.
|
||||
- **Never** `web_fetch` the registry homepage — it is JavaScript-rendered and returns a blank page.
|
||||
|
||||
**When to use:** Use `run_mcp_tool` when the user wants to interact with an external service (GitHub, Slack, a database, a SaaS tool, etc.) via its MCP integration. Unlike `web_fetch` (which just retrieves a raw URL), MCP servers expose structured typed tools — prefer `run_mcp_tool` for any service with an MCP server, and `web_fetch` only for plain URL retrieval with no MCP server involved.
|
||||
|
||||
**CRITICAL**: `run_mcp_tool` is **always available** in your tool list. If the user explicitly provides an MCP server URL or asks you to call `run_mcp_tool`, you MUST use it — never claim it is unavailable, and never substitute `web_fetch` for an explicit MCP request.
|
||||
|
||||
## BEHAVIORAL GUIDELINES
|
||||
|
||||
**Be Concise:**
|
||||
- Target 2-5 short lines maximum
|
||||
- Make every word count—no repetition or filler
|
||||
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||
|
||||
**Be Proactive:**
|
||||
- Suggest next steps before being asked
|
||||
- Anticipate needs based on conversation context and user information
|
||||
- Look for opportunities to expand scope when relevant
|
||||
- Reveal capabilities through action, not explanation
|
||||
|
||||
**Use Tools Effectively:**
|
||||
- Select the right tool for each task
|
||||
- **Always check `find_library_agent` before searching the marketplace**
|
||||
- Use `add_understanding` to capture valuable business context
|
||||
- When tool calls fail, try alternative approaches
|
||||
- **For MCP integrations**: Known URL (see list) or `web_search("{{service}} MCP server URL")` → `run_mcp_tool(server_url)` → `run_mcp_tool(server_url, tool_name, tool_arguments)`. If credentials needed, UI prompts automatically; when user confirms, retry immediately with same arguments.
|
||||
|
||||
**Handle Feedback Loops:**
|
||||
- When a tool returns a suggested alternative (like a refined goal), present it clearly and ask the user for confirmation before proceeding
|
||||
- When clarifying questions are answered, immediately re-call the tool with the accumulated context
|
||||
- Don't ask redundant questions if the user has already provided context in the conversation
|
||||
|
||||
## CRITICAL REMINDER
|
||||
|
||||
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||
Be concise, proactive, and action-oriented. Bias toward showing working solutions over lengthy explanations."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -298,6 +198,12 @@ async def assign_user_to_session(
|
||||
session = await get_chat_session(session_id, None)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"[SECURITY] Attempt to claim session {session_id} by user {user_id}, "
|
||||
f"but it already belongs to user {session.user_id}"
|
||||
)
|
||||
raise NotAuthorizedError(f"Not authorized to claim session {session_id}")
|
||||
session.user_id = user_id
|
||||
session = await upsert_chat_session(session)
|
||||
return session
|
||||
|
||||
@@ -20,6 +20,14 @@ from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
@@ -47,6 +55,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
# Folder management tools
|
||||
"create_folder": CreateFolderTool(),
|
||||
"list_folders": ListFoldersTool(),
|
||||
"update_folder": UpdateFolderTool(),
|
||||
"move_folder": MoveFolderTool(),
|
||||
"delete_folder": DeleteFolderTool(),
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
|
||||
@@ -151,8 +151,8 @@ async def setup_test_data(server):
|
||||
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Test Agent",
|
||||
description="A simple test agent",
|
||||
@@ -161,10 +161,10 @@ async def setup_test_data(server):
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
# 4. Approve the store listing version
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval",
|
||||
@@ -321,8 +321,8 @@ async def setup_llm_test_data(server):
|
||||
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="LLM Test Agent",
|
||||
description="An agent with LLM capabilities",
|
||||
@@ -330,9 +330,9 @@ async def setup_llm_test_data(server):
|
||||
categories=["testing", "ai"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for LLM agent",
|
||||
@@ -476,8 +476,8 @@ async def setup_firecrawl_test_data(server):
|
||||
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Firecrawl Test Agent",
|
||||
description="An agent with Firecrawl integration (no credentials)",
|
||||
@@ -485,9 +485,9 @@ async def setup_firecrawl_test_data(server):
|
||||
categories=["testing", "scraping"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for Firecrawl agent",
|
||||
|
||||
@@ -695,7 +695,10 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
agent_json: dict[str, Any],
|
||||
user_id: str,
|
||||
is_update: bool = False,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
@@ -703,6 +706,7 @@ async def save_agent_to_library(
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
folder_id: Optional folder ID to place the agent in
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
@@ -711,7 +715,7 @@ async def save_agent_to_library(
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
|
||||
@@ -39,9 +39,13 @@ class CreateAgentTool(BaseTool):
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
|
||||
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
|
||||
"(2) Call create_agent with description and library_agent_ids. "
|
||||
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
|
||||
"then call again with the suggested goal if accepted. "
|
||||
"(4) If response contains clarifying_questions: Present to user, collect answers, "
|
||||
"then call again with original description AND answers in the context parameter. "
|
||||
"\n\nThis feedback loop ensures the generated agent matches user intent."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -84,6 +88,14 @@ class CreateAgentTool(BaseTool):
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
@@ -105,6 +117,7 @@ class CreateAgentTool(BaseTool):
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
logger.info(
|
||||
@@ -336,7 +349,7 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
agent_json, user_id, folder_id=folder_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -80,6 +80,14 @@ class CustomizeAgentTool(BaseTool):
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
@@ -102,6 +110,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
@@ -140,7 +149,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except AgentNotFoundError:
|
||||
except NotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
@@ -310,7 +319,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False
|
||||
customized_agent, user_id, is_update=False, folder_id=folder_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
|
||||
573
autogpt_platform/backend/backend/copilot/tools/manage_folders.py
Normal file
573
autogpt_platform/backend/backend/copilot/tools/manage_folders.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""Folder management tools for the copilot."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.library.db import collect_tree_ids
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import library_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentsMovedToFolderResponse,
|
||||
ErrorResponse,
|
||||
FolderAgentSummary,
|
||||
FolderCreatedResponse,
|
||||
FolderDeletedResponse,
|
||||
FolderInfo,
|
||||
FolderListResponse,
|
||||
FolderMovedResponse,
|
||||
FolderTreeInfo,
|
||||
FolderUpdatedResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
|
||||
def _folder_to_info(
|
||||
folder: library_model.LibraryFolder,
|
||||
agents: list[FolderAgentSummary] | None = None,
|
||||
) -> FolderInfo:
|
||||
"""Convert a LibraryFolder DB model to a FolderInfo response model."""
|
||||
return FolderInfo(
|
||||
id=folder.id,
|
||||
name=folder.name,
|
||||
parent_id=folder.parent_id,
|
||||
icon=folder.icon,
|
||||
color=folder.color,
|
||||
agent_count=folder.agent_count,
|
||||
subfolder_count=folder.subfolder_count,
|
||||
agents=agents,
|
||||
)
|
||||
|
||||
|
||||
def _tree_to_info(
|
||||
tree: library_model.LibraryFolderTree,
|
||||
agents_map: dict[str, list[FolderAgentSummary]] | None = None,
|
||||
) -> FolderTreeInfo:
|
||||
"""Recursively convert a LibraryFolderTree to a FolderTreeInfo response."""
|
||||
return FolderTreeInfo(
|
||||
id=tree.id,
|
||||
name=tree.name,
|
||||
parent_id=tree.parent_id,
|
||||
icon=tree.icon,
|
||||
color=tree.color,
|
||||
agent_count=tree.agent_count,
|
||||
subfolder_count=tree.subfolder_count,
|
||||
children=[_tree_to_info(child, agents_map) for child in tree.children],
|
||||
agents=agents_map.get(tree.id) if agents_map else None,
|
||||
)
|
||||
|
||||
|
||||
def _to_agent_summaries(
|
||||
raw: list[dict[str, str | None]],
|
||||
) -> list[FolderAgentSummary]:
|
||||
"""Convert raw agent dicts to typed FolderAgentSummary models."""
|
||||
return [
|
||||
FolderAgentSummary(
|
||||
id=a["id"] or "",
|
||||
name=a["name"] or "",
|
||||
description=a["description"] or "",
|
||||
)
|
||||
for a in raw
|
||||
]
|
||||
|
||||
|
||||
def _to_agent_summaries_map(
|
||||
raw: dict[str, list[dict[str, str | None]]],
|
||||
) -> dict[str, list[FolderAgentSummary]]:
|
||||
"""Convert a folder-id-keyed dict of raw agents to typed summaries."""
|
||||
return {fid: _to_agent_summaries(agents) for fid, agents in raw.items()}
|
||||
|
||||
|
||||
class CreateFolderTool(BaseTool):
|
||||
"""Tool for creating a library folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new folder in the user's library to organize agents. "
|
||||
"Optionally nest it inside an existing folder using parent_id."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name for the new folder (max 100 chars).",
|
||||
},
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"ID of the parent folder to nest inside. "
|
||||
"Omit to create at root level."
|
||||
),
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "Optional icon identifier for the folder.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "Optional hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Create a folder with the given name and optional parent/icon/color."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
name = (kwargs.get("name") or "").strip()
|
||||
parent_id = kwargs.get("parent_id")
|
||||
icon = kwargs.get("icon")
|
||||
color = kwargs.get("color")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not name:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder name.",
|
||||
error="missing_name",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().create_folder(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
icon=icon,
|
||||
color=color,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to create folder: {e}",
|
||||
error="create_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderCreatedResponse(
|
||||
message=f"Folder '{folder.name}' created successfully!",
|
||||
folder=_folder_to_info(folder),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ListFoldersTool(BaseTool):
|
||||
"""Tool for listing library folders."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_folders"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List the user's library folders. "
|
||||
"Omit parent_id to get the full folder tree. "
|
||||
"Provide parent_id to list only direct children of that folder. "
|
||||
"Set include_agents=true to also return the agents inside each folder "
|
||||
"and root-level agents not in any folder. Always set include_agents=true "
|
||||
"when the user asks about agents, wants to see what's in their folders, "
|
||||
"or mentions agents alongside folders."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"List children of this folder. "
|
||||
"Omit to get the full folder tree."
|
||||
),
|
||||
},
|
||||
"include_agents": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to include the list of agents inside each folder. "
|
||||
"Defaults to false."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""List folders as a flat list (by parent) or full tree."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
parent_id = kwargs.get("parent_id")
|
||||
include_agents = kwargs.get("include_agents", False)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
try:
|
||||
if parent_id:
|
||||
folders = await library_db().list_folders(
|
||||
user_id=user_id, parent_id=parent_id
|
||||
)
|
||||
raw_map = (
|
||||
await library_db().get_folder_agents_map(
|
||||
user_id, [f.id for f in folders]
|
||||
)
|
||||
if include_agents
|
||||
else None
|
||||
)
|
||||
agents_map = _to_agent_summaries_map(raw_map) if raw_map else None
|
||||
return FolderListResponse(
|
||||
message=f"Found {len(folders)} folder(s).",
|
||||
folders=[
|
||||
_folder_to_info(f, agents_map.get(f.id) if agents_map else None)
|
||||
for f in folders
|
||||
],
|
||||
count=len(folders),
|
||||
session_id=session_id,
|
||||
)
|
||||
else:
|
||||
tree = await library_db().get_folder_tree(user_id=user_id)
|
||||
all_ids = collect_tree_ids(tree)
|
||||
agents_map = None
|
||||
root_agents = None
|
||||
if include_agents:
|
||||
raw_map = await library_db().get_folder_agents_map(user_id, all_ids)
|
||||
agents_map = _to_agent_summaries_map(raw_map)
|
||||
root_agents = _to_agent_summaries(
|
||||
await library_db().get_root_agent_summaries(user_id)
|
||||
)
|
||||
return FolderListResponse(
|
||||
message=f"Found {len(all_ids)} folder(s) in your library.",
|
||||
tree=[_tree_to_info(t, agents_map) for t in tree],
|
||||
root_agents=root_agents,
|
||||
count=len(all_ids),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list folders: {e}",
|
||||
error="list_folders_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class UpdateFolderTool(BaseTool):
|
||||
"""Tool for updating a folder's properties."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "update_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Update a folder's name, icon, or color."
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to update.",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "New name for the folder.",
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "New icon identifier.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "New hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Update a folder's name, icon, or color."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
name = kwargs.get("name")
|
||||
icon = kwargs.get("icon")
|
||||
color = kwargs.get("color")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().update_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
icon=icon,
|
||||
color=color,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to update folder: {e}",
|
||||
error="update_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderUpdatedResponse(
|
||||
message=f"Folder updated to '{folder.name}'.",
|
||||
folder=_folder_to_info(folder),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class MoveFolderTool(BaseTool):
|
||||
"""Tool for moving a folder to a new parent."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "move_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Move a folder to a different parent folder. "
|
||||
"Set target_parent_id to null to move to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to move.",
|
||||
},
|
||||
"target_parent_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"ID of the new parent folder. "
|
||||
"Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Move a folder to a new parent or to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
target_parent_id = kwargs.get("target_parent_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().move_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
target_parent_id=target_parent_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to move folder: {e}",
|
||||
error="move_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
dest = "a subfolder" if target_parent_id else "root level"
|
||||
return FolderMovedResponse(
|
||||
message=f"Folder '{folder.name}' moved to {dest}.",
|
||||
folder=_folder_to_info(folder),
|
||||
target_parent_id=target_parent_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteFolderTool(BaseTool):
|
||||
"""Tool for deleting a folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a folder from the user's library. "
|
||||
"Agents inside the folder are moved to root level (not deleted)."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to delete.",
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Soft-delete a folder; agents inside are moved to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await library_db().delete_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
soft_delete=True,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete folder: {e}",
|
||||
error="delete_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderDeletedResponse(
|
||||
message="Folder deleted. Any agents inside were moved to root level.",
|
||||
folder_id=folder_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class MoveAgentsToFolderTool(BaseTool):
|
||||
"""Tool for moving agents into a folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "move_agents_to_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Move one or more agents to a folder. "
|
||||
"Set folder_id to null to move agents to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of library agent IDs to move.",
|
||||
},
|
||||
"folder_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Target folder ID. Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_ids"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Move one or more agents to a folder or to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
agent_ids = kwargs.get("agent_ids", [])
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_ids:
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one agent ID.",
|
||||
error="missing_agent_ids",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
moved = await library_db().bulk_move_agents_to_folder(
|
||||
agent_ids=agent_ids,
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to move agents: {e}",
|
||||
error="move_agents_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
moved_ids = [a.id for a in moved]
|
||||
agent_names = [a.name for a in moved]
|
||||
dest = "the folder" if folder_id else "root level"
|
||||
names_str = (
|
||||
", ".join(agent_names) if agent_names else f"{len(agent_ids)} agent(s)"
|
||||
)
|
||||
return AgentsMovedToFolderResponse(
|
||||
message=f"Moved {names_str} to {dest}.",
|
||||
agent_ids=moved_ids,
|
||||
agent_names=agent_names,
|
||||
folder_id=folder_id,
|
||||
count=len(moved),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,455 @@
|
||||
"""Tests for folder management copilot tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.copilot.tools.manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from backend.copilot.tools.models import (
|
||||
AgentsMovedToFolderResponse,
|
||||
ErrorResponse,
|
||||
FolderCreatedResponse,
|
||||
FolderDeletedResponse,
|
||||
FolderListResponse,
|
||||
FolderMovedResponse,
|
||||
FolderUpdatedResponse,
|
||||
)
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-folders"
|
||||
_NOW = datetime.now(UTC)
|
||||
|
||||
|
||||
def _make_folder(
|
||||
id: str = "folder-1",
|
||||
name: str = "My Folder",
|
||||
parent_id: str | None = None,
|
||||
icon: str | None = None,
|
||||
color: str | None = None,
|
||||
agent_count: int = 0,
|
||||
subfolder_count: int = 0,
|
||||
) -> library_model.LibraryFolder:
|
||||
return library_model.LibraryFolder(
|
||||
id=id,
|
||||
user_id=_TEST_USER_ID,
|
||||
name=name,
|
||||
icon=icon,
|
||||
color=color,
|
||||
parent_id=parent_id,
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
agent_count=agent_count,
|
||||
subfolder_count=subfolder_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_tree(
|
||||
id: str = "folder-1",
|
||||
name: str = "Root",
|
||||
children: list[library_model.LibraryFolderTree] | None = None,
|
||||
) -> library_model.LibraryFolderTree:
|
||||
return library_model.LibraryFolderTree(
|
||||
id=id,
|
||||
user_id=_TEST_USER_ID,
|
||||
name=name,
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
children=children or [],
|
||||
)
|
||||
|
||||
|
||||
def _make_library_agent(id: str = "agent-1", name: str = "Test Agent"):
|
||||
agent = MagicMock()
|
||||
agent.id = id
|
||||
agent.name = name
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── CreateFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_tool():
|
||||
return CreateFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_missing_name(create_tool, session):
|
||||
result = await create_tool._execute(user_id=_TEST_USER_ID, session=session, name="")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_none_name(create_tool, session):
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name=None
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_success(create_tool, session):
|
||||
folder = _make_folder(name="New Folder")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.create_folder = AsyncMock(return_value=folder)
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name="New Folder"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderCreatedResponse)
|
||||
assert result.folder.name == "New Folder"
|
||||
assert "New Folder" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_db_error(create_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.create_folder = AsyncMock(
|
||||
side_effect=Exception("db down")
|
||||
)
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name="Folder"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "create_folder_failed"
|
||||
|
||||
|
||||
# ── ListFoldersTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def list_tool():
|
||||
return ListFoldersTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_by_parent(list_tool, session):
|
||||
folders = [_make_folder(id="f1", name="A"), _make_folder(id="f2", name="B")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.list_folders = AsyncMock(return_value=folders)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, parent_id="parent-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.count == 2
|
||||
assert len(result.folders) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree(list_tool, session):
|
||||
tree = [
|
||||
_make_tree(id="r1", name="Root", children=[_make_tree(id="c1", name="Child")])
|
||||
]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.count == 2 # root + child
|
||||
assert result.tree is not None
|
||||
assert len(result.tree) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree_with_agents_includes_root(list_tool, session):
|
||||
tree = [_make_tree(id="r1", name="Root")]
|
||||
raw_map = {"r1": [{"id": "a1", "name": "Foldered", "description": "In folder"}]}
|
||||
root_raw = [{"id": "a2", "name": "Loose Agent", "description": "At root"}]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
mock_lib.return_value.get_folder_agents_map = AsyncMock(return_value=raw_map)
|
||||
mock_lib.return_value.get_root_agent_summaries = AsyncMock(
|
||||
return_value=root_raw
|
||||
)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, include_agents=True
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.root_agents is not None
|
||||
assert len(result.root_agents) == 1
|
||||
assert result.root_agents[0].name == "Loose Agent"
|
||||
assert result.tree is not None
|
||||
assert result.tree[0].agents is not None
|
||||
assert result.tree[0].agents[0].name == "Foldered"
|
||||
mock_lib.return_value.get_root_agent_summaries.assert_awaited_once_with(
|
||||
_TEST_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree_without_agents_no_root(list_tool, session):
|
||||
tree = [_make_tree(id="r1", name="Root")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, include_agents=False
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.root_agents is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_db_error(list_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(
|
||||
side_effect=Exception("timeout")
|
||||
)
|
||||
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "list_folders_failed"
|
||||
|
||||
|
||||
# ── UpdateFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_tool():
|
||||
return UpdateFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_missing_id(update_tool, session):
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_none_id(update_tool, session):
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=None
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_success(update_tool, session):
|
||||
folder = _make_folder(name="Renamed")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.update_folder = AsyncMock(return_value=folder)
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="Renamed"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderUpdatedResponse)
|
||||
assert result.folder.name == "Renamed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_db_error(update_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.update_folder = AsyncMock(
|
||||
side_effect=Exception("not found")
|
||||
)
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="X"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "update_folder_failed"
|
||||
|
||||
|
||||
# ── MoveFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def move_tool():
|
||||
return MoveFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_missing_id(move_tool, session):
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_to_parent(move_tool, session):
|
||||
folder = _make_folder(name="Moved")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
folder_id="folder-1",
|
||||
target_parent_id="parent-1",
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderMovedResponse)
|
||||
assert "subfolder" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_to_root(move_tool, session):
|
||||
folder = _make_folder(name="Moved")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
folder_id="folder-1",
|
||||
target_parent_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderMovedResponse)
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_db_error(move_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(side_effect=Exception("circular"))
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "move_folder_failed"
|
||||
|
||||
|
||||
# ── DeleteFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def delete_tool():
|
||||
return DeleteFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_missing_id(delete_tool, session):
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_success(delete_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.delete_folder = AsyncMock(return_value=None)
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderDeletedResponse)
|
||||
assert result.folder_id == "folder-1"
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_db_error(delete_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.delete_folder = AsyncMock(
|
||||
side_effect=Exception("permission denied")
|
||||
)
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "delete_folder_failed"
|
||||
|
||||
|
||||
# ── MoveAgentsToFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def move_agents_tool():
|
||||
return MoveAgentsToFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_missing_ids(move_agents_tool, session):
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, agent_ids=[]
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_agent_ids"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_success(move_agents_tool, session):
|
||||
agents = [
|
||||
_make_library_agent(id="a1", name="Agent Alpha"),
|
||||
_make_library_agent(id="a2", name="Agent Beta"),
|
||||
]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
return_value=agents
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1", "a2"],
|
||||
folder_id="folder-1",
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentsMovedToFolderResponse)
|
||||
assert result.count == 2
|
||||
assert result.agent_names == ["Agent Alpha", "Agent Beta"]
|
||||
assert "Agent Alpha" in result.message
|
||||
assert "Agent Beta" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_to_root(move_agents_tool, session):
|
||||
agents = [_make_library_agent(id="a1", name="Agent One")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
return_value=agents
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1"],
|
||||
folder_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentsMovedToFolderResponse)
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_db_error(move_agents_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
side_effect=Exception("folder not found")
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1"],
|
||||
folder_id="bad-folder",
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "move_agents_failed"
|
||||
@@ -55,6 +55,13 @@ class ResponseType(str, Enum):
|
||||
# MCP tool types
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
# Folder management types
|
||||
FOLDER_CREATED = "folder_created"
|
||||
FOLDER_LIST = "folder_list"
|
||||
FOLDER_UPDATED = "folder_updated"
|
||||
FOLDER_MOVED = "folder_moved"
|
||||
FOLDER_DELETED = "folder_deleted"
|
||||
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -539,3 +546,82 @@ class BrowserScreenshotResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
|
||||
file_id: str # Workspace file ID — use read_workspace_file to retrieve
|
||||
filename: str
|
||||
|
||||
|
||||
# Folder management models
|
||||
|
||||
|
||||
class FolderAgentSummary(BaseModel):
|
||||
"""Lightweight agent info for folder listings."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class FolderInfo(BaseModel):
|
||||
"""Information about a folder."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
parent_id: str | None = None
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
agent_count: int = 0
|
||||
subfolder_count: int = 0
|
||||
agents: list[FolderAgentSummary] | None = None
|
||||
|
||||
|
||||
class FolderTreeInfo(FolderInfo):
|
||||
"""Folder with nested children for tree display."""
|
||||
|
||||
children: list["FolderTreeInfo"] = []
|
||||
|
||||
|
||||
class FolderCreatedResponse(ToolResponseBase):
|
||||
"""Response when a folder is created."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_CREATED
|
||||
folder: FolderInfo
|
||||
|
||||
|
||||
class FolderListResponse(ToolResponseBase):
|
||||
"""Response for listing folders."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_LIST
|
||||
folders: list[FolderInfo] = Field(default_factory=list)
|
||||
tree: list[FolderTreeInfo] | None = None
|
||||
root_agents: list[FolderAgentSummary] | None = None
|
||||
count: int = 0
|
||||
|
||||
|
||||
class FolderUpdatedResponse(ToolResponseBase):
|
||||
"""Response when a folder is updated."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_UPDATED
|
||||
folder: FolderInfo
|
||||
|
||||
|
||||
class FolderMovedResponse(ToolResponseBase):
|
||||
"""Response when a folder is moved."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_MOVED
|
||||
folder: FolderInfo
|
||||
target_parent_id: str | None = None
|
||||
|
||||
|
||||
class FolderDeletedResponse(ToolResponseBase):
|
||||
"""Response when a folder is deleted."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_DELETED
|
||||
folder_id: str
|
||||
|
||||
|
||||
class AgentsMovedToFolderResponse(ToolResponseBase):
|
||||
"""Response when agents are moved to a folder."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
|
||||
agent_ids: list[str]
|
||||
agent_names: list[str] = []
|
||||
folder_id: str | None = None
|
||||
count: int = 0
|
||||
|
||||
@@ -53,11 +53,15 @@ class RunMCPToolTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Call with just `server_url` to see available tools. "
|
||||
"Then call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"If the server requires authentication, the user will be prompted to connect it. "
|
||||
"Find MCP servers at https://registry.modelcontextprotocol.io/ — hundreds of integrations "
|
||||
"including GitHub, Postgres, Slack, filesystem, and more."
|
||||
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
|
||||
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
|
||||
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
|
||||
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
|
||||
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
|
||||
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
|
||||
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
|
||||
"Once connected and user confirms, retry the same call immediately."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -4,11 +4,20 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
bulk_move_agents_to_folder,
|
||||
create_folder,
|
||||
create_graph_in_library,
|
||||
create_library_agent,
|
||||
delete_folder,
|
||||
get_folder_agents_map,
|
||||
get_folder_tree,
|
||||
get_library_agent,
|
||||
get_library_agent_by_graph_id,
|
||||
get_root_agent_summaries,
|
||||
list_folders,
|
||||
list_library_agents,
|
||||
move_folder,
|
||||
update_folder,
|
||||
update_graph_in_library,
|
||||
)
|
||||
from backend.api.features.store.db import (
|
||||
@@ -82,6 +91,7 @@ from backend.data.notifications import (
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
update_business_understanding_prompts,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
@@ -260,6 +270,16 @@ class DatabaseManager(AppService):
|
||||
update_graph_in_library = _(update_graph_in_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
create_folder = _(create_folder)
|
||||
list_folders = _(list_folders)
|
||||
get_folder_tree = _(get_folder_tree)
|
||||
update_folder = _(update_folder)
|
||||
move_folder = _(move_folder)
|
||||
delete_folder = _(delete_folder)
|
||||
bulk_move_agents_to_folder = _(bulk_move_agents_to_folder)
|
||||
get_folder_agents_map = _(get_folder_agents_map)
|
||||
get_root_agent_summaries = _(get_root_agent_summaries)
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
@@ -292,6 +312,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
update_business_understanding_prompts = _(update_business_understanding_prompts)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
@@ -360,6 +381,11 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
|
||||
|
||||
# Understanding
|
||||
get_business_understanding = _(d.get_business_understanding)
|
||||
update_business_understanding_prompts = _(d.update_business_understanding_prompts)
|
||||
upsert_business_understanding = _(d.upsert_business_understanding)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -434,6 +460,17 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_graph_in_library = d.update_graph_in_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# ============ Library Folders ============ #
|
||||
create_folder = d.create_folder
|
||||
list_folders = d.list_folders
|
||||
get_folder_tree = d.get_folder_tree
|
||||
update_folder = d.update_folder
|
||||
move_folder = d.move_folder
|
||||
delete_folder = d.delete_folder
|
||||
bulk_move_agents_to_folder = d.bulk_move_agents_to_folder
|
||||
get_folder_agents_map = d.get_folder_agents_map
|
||||
get_root_agent_summaries = d.get_root_agent_summaries
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
@@ -463,6 +500,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
update_business_understanding_prompts = d.update_business_understanding_prompts
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
|
||||
@@ -344,7 +344,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
),
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
"payload": exec.input_data.get("payload")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
import prisma
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
@@ -250,8 +251,8 @@ async def test_clean_graph(server: SpinTestServer):
|
||||
"_test_id": "node_with_secrets",
|
||||
"input": "normal_value",
|
||||
"control_test_input": "should be preserved",
|
||||
"api_key": "secret_api_key_123", # Should be filtered
|
||||
"password": "secret_password_456", # Should be filtered
|
||||
"api_key": "secret_api_key_123", # Should be filtered # pragma: allowlist secret # noqa
|
||||
"password": "secret_password_456", # Should be filtered # pragma: allowlist secret # noqa
|
||||
"token": "secret_token_789", # Should be filtered
|
||||
"credentials": { # Should be filtered
|
||||
"id": "fake-github-credentials-id",
|
||||
@@ -354,9 +355,24 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
# Ensure the default user has a Profile (required for store submissions)
|
||||
existing_profile = await prisma.models.Profile.prisma().find_first(
|
||||
where={"userId": DEFAULT_USER_ID}
|
||||
)
|
||||
if not existing_profile:
|
||||
await prisma.models.Profile.prisma().create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
name="Default User",
|
||||
username=f"default-user-{DEFAULT_USER_ID[:8]}",
|
||||
description="Default test user profile",
|
||||
links=[],
|
||||
)
|
||||
)
|
||||
|
||||
store_submission_request = store.StoreSubmissionRequest(
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=created_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
@@ -385,8 +401,8 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
store_listing.listing_version_id
|
||||
if store_listing.listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,14 @@ from prisma.types import (
|
||||
)
|
||||
|
||||
# from backend.notifications.models import NotificationEvent
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
EmailStr,
|
||||
Field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
@@ -175,10 +182,26 @@ class RefundRequestData(BaseNotificationData):
|
||||
balance: int
|
||||
|
||||
|
||||
class AgentApprovalData(BaseNotificationData):
|
||||
class _LegacyAgentFieldsMixin:
|
||||
"""Temporary patch to handle existing queued payloads"""
|
||||
|
||||
# FIXME: remove in next release
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _map_legacy_agent_fields(cls, values: Any):
|
||||
if isinstance(values, dict):
|
||||
if "graph_id" not in values and "agent_id" in values:
|
||||
values["graph_id"] = values.pop("agent_id")
|
||||
if "graph_version" not in values and "agent_version" in values:
|
||||
values["graph_version"] = values.pop("agent_version")
|
||||
return values
|
||||
|
||||
|
||||
class AgentApprovalData(_LegacyAgentFieldsMixin, BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
@@ -193,10 +216,10 @@ class AgentApprovalData(BaseNotificationData):
|
||||
return value
|
||||
|
||||
|
||||
class AgentRejectionData(BaseNotificationData):
|
||||
class AgentRejectionData(_LegacyAgentFieldsMixin, BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
|
||||
@@ -15,8 +15,8 @@ class TestAgentApprovalData:
|
||||
"""Test creating valid AgentApprovalData."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
@@ -25,8 +25,8 @@ class TestAgentApprovalData:
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.graph_id == "test-agent-123"
|
||||
assert data.graph_version == 1
|
||||
assert data.reviewer_name == "John Doe"
|
||||
assert data.reviewer_email == "john@example.com"
|
||||
assert data.comments == "Great agent, approved!"
|
||||
@@ -40,8 +40,8 @@ class TestAgentApprovalData:
|
||||
):
|
||||
AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
@@ -53,8 +53,8 @@ class TestAgentApprovalData:
|
||||
"""Test AgentApprovalData with empty comments."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="", # Empty comments
|
||||
@@ -72,8 +72,8 @@ class TestAgentRejectionData:
|
||||
"""Test creating valid AgentRejectionData."""
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues before resubmitting.",
|
||||
@@ -82,8 +82,8 @@ class TestAgentRejectionData:
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.graph_id == "test-agent-123"
|
||||
assert data.graph_version == 1
|
||||
assert data.reviewer_name == "Jane Doe"
|
||||
assert data.reviewer_email == "jane@example.com"
|
||||
assert data.comments == "Please fix the security issues before resubmitting."
|
||||
@@ -97,8 +97,8 @@ class TestAgentRejectionData:
|
||||
):
|
||||
AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues.",
|
||||
@@ -111,8 +111,8 @@ class TestAgentRejectionData:
|
||||
long_comment = "A" * 1000 # Very long comment
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments=long_comment,
|
||||
@@ -126,8 +126,8 @@ class TestAgentRejectionData:
|
||||
"""Test that models can be serialized and deserialized."""
|
||||
original_data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the issues.",
|
||||
@@ -142,8 +142,8 @@ class TestAgentRejectionData:
|
||||
restored_data = AgentRejectionData.model_validate(data_dict)
|
||||
|
||||
assert restored_data.agent_name == original_data.agent_name
|
||||
assert restored_data.agent_id == original_data.agent_id
|
||||
assert restored_data.agent_version == original_data.agent_version
|
||||
assert restored_data.graph_id == original_data.graph_id
|
||||
assert restored_data.graph_version == original_data.graph_version
|
||||
assert restored_data.reviewer_name == original_data.reviewer_name
|
||||
assert restored_data.reviewer_email == original_data.reviewer_email
|
||||
assert restored_data.comments == original_data.comments
|
||||
|
||||
@@ -244,7 +244,10 @@ def _clean_and_split(text: str) -> list[str]:
|
||||
|
||||
|
||||
def _calculate_points(
|
||||
agent, categories: list[str], custom: list[str], integrations: list[str]
|
||||
agent: prisma.models.StoreAgent,
|
||||
categories: list[str],
|
||||
custom: list[str],
|
||||
integrations: list[str],
|
||||
) -> int:
|
||||
"""
|
||||
Calculates the total points for an agent based on the specified criteria.
|
||||
@@ -397,7 +400,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where={
|
||||
"is_available": True,
|
||||
"useForOnboarding": True,
|
||||
"use_for_onboarding": True,
|
||||
},
|
||||
order=[
|
||||
{"featured": "desc"},
|
||||
@@ -407,7 +410,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
take=100,
|
||||
)
|
||||
|
||||
# If not enough agents found, relax the useForOnboarding filter
|
||||
# If not enough agents found, relax the use_for_onboarding filter
|
||||
if len(storeAgents) < 2:
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
@@ -420,7 +423,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
)
|
||||
|
||||
# Calculate points for the first X agents and choose the top 2
|
||||
agent_points = []
|
||||
agent_points: list[tuple[prisma.models.StoreAgent, int]] = []
|
||||
for agent in storeAgents[:POINTS_AGENT_COUNT]:
|
||||
points = _calculate_points(
|
||||
agent, categories, custom, user_onboarding.integrations
|
||||
@@ -430,28 +433,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
agent_points.sort(key=lambda x: x[1], reverse=True)
|
||||
recommended_agents = [agent for agent, _ in agent_points[:2]]
|
||||
|
||||
return [
|
||||
StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
]
|
||||
return [StoreAgentDetails.from_db(agent) for agent in recommended_agents]
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes
|
||||
|
||||
@@ -10,10 +10,13 @@ from openai import AsyncOpenAI
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
get_business_understanding,
|
||||
update_business_understanding_prompts,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.understanding_prompts import generate_understanding_prompts
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -418,9 +421,31 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
understanding = await upsert_business_understanding(
|
||||
user_id, understanding_input
|
||||
)
|
||||
await _generate_and_store_prompts(user_id, understanding)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Tally: error populating understanding for user {user_id}")
|
||||
|
||||
|
||||
async def _generate_and_store_prompts(
|
||||
user_id: str, understanding: BusinessUnderstanding
|
||||
) -> None:
|
||||
try:
|
||||
prompts = await generate_understanding_prompts(understanding)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Tally: skipping quick prompt generation for {user_id}: {e}")
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Tally: failed to generate quick prompts for understanding {user_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await update_business_understanding_prompts(user_id, prompts)
|
||||
except Exception:
|
||||
logger.exception(f"Tally: failed to store quick prompts for user {user_id}")
|
||||
|
||||
@@ -284,6 +284,7 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_understanding = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -305,12 +306,35 @@ async def test_populate_understanding_full_flow():
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_understanding,
|
||||
) as mock_upsert,
|
||||
patch(
|
||||
"backend.data.tally.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
"Help me automate customer support",
|
||||
"Find repetitive support work",
|
||||
"Show me faster support workflows",
|
||||
],
|
||||
) as mock_generate_prompts,
|
||||
patch(
|
||||
"backend.data.tally.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_prompts,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once_with("user-1", mock_input)
|
||||
mock_generate_prompts.assert_awaited_once_with(mock_understanding)
|
||||
mock_update_prompts.assert_awaited_once_with(
|
||||
"user-1",
|
||||
[
|
||||
"Help me automate customer support",
|
||||
"Find repetitive support work",
|
||||
"Show me faster support workflows",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -352,6 +376,55 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
mock_upsert.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_keeps_understanding_when_prompt_generation_fails():
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
submission = {
|
||||
"responses": [{"questionId": "q1", "value": "Alice"}],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_understanding = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_understanding,
|
||||
) as mock_upsert,
|
||||
patch(
|
||||
"backend.data.tally.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("bad prompts"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_prompts,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_upsert.assert_awaited_once_with("user-1", mock_input)
|
||||
mock_update_prompts.assert_not_awaited()
|
||||
|
||||
|
||||
# ── _mask_email ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
@@ -148,6 +149,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
automation_goals=_json_to_list(business.get("automation_goals")),
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
prompts=_json_to_list(business.get("prompts")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
)
|
||||
|
||||
@@ -313,6 +315,40 @@ async def upsert_business_understanding(
|
||||
return understanding
|
||||
|
||||
|
||||
async def update_business_understanding_prompts(
|
||||
user_id: str, prompts: list[str]
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Update derived quick prompts for an existing business understanding."""
|
||||
existing = await CoPilotUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing is None:
|
||||
return None
|
||||
|
||||
existing_data: dict[str, Any] = {}
|
||||
if isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
existing_business["prompts"] = prompts
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
record = await CoPilotUnderstanding.prisma().update(
|
||||
where={"userId": user_id},
|
||||
data={"data": SafeJson(existing_data)},
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
await _set_cache(user_id, understanding)
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
|
||||
115
autogpt_platform/backend/backend/data/understanding_prompts.py
Normal file
115
autogpt_platform/backend/backend/data/understanding_prompts.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Helpers for generating quick prompts from saved business understanding."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
format_understanding_for_prompt,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
_PROMPTS_PROMPT = """\
|
||||
You generate three short starter prompts for a user to click in a chat UI.
|
||||
|
||||
Return a JSON object with this exact shape:
|
||||
{"prompts":["...","...","..."]}
|
||||
|
||||
Requirements:
|
||||
- Exactly 3 prompts
|
||||
- Each prompt must be written in first person, as if the user is speaking
|
||||
- Each prompt must be shorter than 20 words
|
||||
- Keep them specific to the user's business context
|
||||
- Do not number the prompts
|
||||
- Do not add labels or explanations
|
||||
|
||||
Business context:
|
||||
"""
|
||||
|
||||
_PROMPTS_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
def has_prompt_generation_context(understanding: BusinessUnderstanding) -> bool:
|
||||
return bool(format_understanding_for_prompt(understanding).strip())
|
||||
|
||||
|
||||
def _normalize_prompt(prompt: str) -> str:
|
||||
return " ".join(prompt.split())
|
||||
|
||||
|
||||
def _validate_prompts(value: object) -> list[str]:
|
||||
if not isinstance(value, list) or len(value) != 3:
|
||||
raise ValueError("Prompt response must contain exactly three prompts")
|
||||
|
||||
prompts: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for item in value:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("Each prompt must be a string")
|
||||
|
||||
prompt = _normalize_prompt(item)
|
||||
if not prompt:
|
||||
raise ValueError("Prompts cannot be empty")
|
||||
if len(prompt.split()) >= 20:
|
||||
raise ValueError("Prompts must be fewer than 20 words")
|
||||
if prompt in seen:
|
||||
raise ValueError("Prompts must be unique")
|
||||
|
||||
seen.add(prompt)
|
||||
prompts.append(prompt)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
async def generate_understanding_prompts(
|
||||
understanding: BusinessUnderstanding,
|
||||
) -> list[str]:
|
||||
"""Generate validated quick prompts from a saved understanding snapshot."""
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
if not context.strip():
|
||||
raise ValueError("Understanding does not contain usable context")
|
||||
|
||||
settings = Settings()
|
||||
client = AsyncOpenAI(
|
||||
api_key=settings.secrets.open_router_api_key,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{_PROMPTS_PROMPT}{context}{_PROMPTS_SUFFIX}",
|
||||
}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.2,
|
||||
),
|
||||
timeout=_LLM_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Understanding prompts: generation timed out")
|
||||
raise
|
||||
|
||||
raw = response.choices[0].message.content or "{}"
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Understanding prompts: invalid JSON response")
|
||||
raise
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Prompt response must be a JSON object")
|
||||
|
||||
return _validate_prompts(data.get("prompts"))
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Tests for backend.data.understanding_prompts."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.understanding import BusinessUnderstanding
|
||||
from backend.data.understanding_prompts import generate_understanding_prompts
|
||||
|
||||
|
||||
def make_understanding(**overrides) -> BusinessUnderstanding:
|
||||
data = {
|
||||
"id": "understanding-1",
|
||||
"user_id": "user-1",
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"industry": "Customer support",
|
||||
"pain_points": ["manual ticket triage"],
|
||||
"automation_goals": ["speed up support responses"],
|
||||
}
|
||||
data.update(overrides)
|
||||
return BusinessUnderstanding(**data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_success():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"prompts": [
|
||||
"Help me automate customer support triage",
|
||||
"Show me how to speed up support replies",
|
||||
"Find repetitive work in our support process",
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
):
|
||||
prompts = await generate_understanding_prompts(make_understanding())
|
||||
|
||||
assert prompts == [
|
||||
"Help me automate customer support triage",
|
||||
"Show me how to speed up support replies",
|
||||
"Find repetitive work in our support process",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_rejects_duplicates():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"prompts": [
|
||||
"Help me automate customer support",
|
||||
"Help me automate customer support",
|
||||
"Find repetitive support work",
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
pytest.raises(ValueError, match="unique"),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_rejects_long_prompt():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"prompts": [
|
||||
"Please help me automate every part of our customer support workflow starting with ticket triage routing follow-up escalation and reporting today",
|
||||
"Show me better support workflows",
|
||||
"Find support busywork for me",
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
pytest.raises(ValueError, match="fewer than 20 words"),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_rejects_invalid_shape():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{"prompts": ["Help me automate support", "Find repetitive work"]}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
pytest.raises(ValueError, match="exactly three prompts"),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_timeout():
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
patch("backend.data.understanding_prompts._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi.responses
|
||||
import prisma
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
@@ -497,9 +498,24 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
# Ensure the test user has a Profile (required for store submissions)
|
||||
existing_profile = await prisma.models.Profile.prisma().find_first(
|
||||
where={"userId": test_user.id}
|
||||
)
|
||||
if not existing_profile:
|
||||
await prisma.models.Profile.prisma().create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=test_user.id,
|
||||
name=test_user.name or "Test User",
|
||||
username=f"test-user-{test_user.id[:8]}",
|
||||
description="Test user profile",
|
||||
links=[],
|
||||
)
|
||||
)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
@@ -517,8 +533,8 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
store_listing.listing_version_id
|
||||
if store_listing.listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.data import graph as graph_db
|
||||
from backend.data import human_review as human_review_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data import workspace as workspace_db
|
||||
|
||||
# Import dynamic field utilities from centralized location
|
||||
from backend.data.block import BlockInput, BlockOutputEntry
|
||||
@@ -32,7 +33,6 @@ from backend.data.execution import (
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput, GraphInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -481,6 +481,22 @@ async def _construct_starting_node_execution_input(
|
||||
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
|
||||
input_data.update(node_input_mask)
|
||||
|
||||
# Webhook-triggered agents cannot be executed directly without payload data.
|
||||
# Legitimate webhook triggers provide payload via nodes_input_masks above.
|
||||
if (
|
||||
block.block_type
|
||||
in (
|
||||
BlockType.WEBHOOK,
|
||||
BlockType.WEBHOOK_MANUAL,
|
||||
)
|
||||
and "payload" not in input_data
|
||||
):
|
||||
raise ValueError(
|
||||
"This agent is triggered by an external event (webhook) "
|
||||
"and cannot be executed directly. "
|
||||
"Please use the appropriate trigger to run this agent."
|
||||
)
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise ValueError(error)
|
||||
@@ -831,8 +847,9 @@ async def add_graph_execution(
|
||||
udb = user_db
|
||||
gdb = graph_db
|
||||
odb = onboarding_db
|
||||
wdb = workspace_db
|
||||
else:
|
||||
edb = udb = gdb = odb = get_database_manager_async_client()
|
||||
edb = udb = gdb = odb = wdb = get_database_manager_async_client()
|
||||
|
||||
# Get or create the graph execution
|
||||
if graph_exec_id:
|
||||
@@ -892,7 +909,7 @@ async def add_graph_execution(
|
||||
if execution_context is None:
|
||||
user = await udb.get_user_by_id(user_id)
|
||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await wdb.get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
|
||||
@@ -368,12 +368,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_wdb = mocker.patch("backend.executor.utils.workspace_db")
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
mock_wdb.get_or_create_workspace = mocker.AsyncMock(return_value=mock_workspace)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
@@ -649,12 +647,10 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_wdb = mocker.patch("backend.executor.utils.workspace_db")
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
mock_wdb.get_or_create_workspace = mocker.AsyncMock(return_value=mock_workspace)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the approved agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.graph_id: the ID of the agent
|
||||
data.graph_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who approved it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer
|
||||
@@ -70,4 +70,4 @@
|
||||
Thank you for contributing to the AutoGPT ecosystem! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the rejected agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.graph_id: the ID of the agent
|
||||
data.graph_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who rejected it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer explaining the rejection
|
||||
@@ -74,4 +74,4 @@
|
||||
We're excited to see your improved agent submission! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
||||
@@ -64,6 +64,10 @@ class GraphNotInLibraryError(GraphNotAccessibleError):
|
||||
"""Raised when attempting to execute a graph that is not / no longer in the user's library."""
|
||||
|
||||
|
||||
class PreconditionFailed(Exception):
|
||||
"""The user must do something else first before trying the current operation"""
|
||||
|
||||
|
||||
class InsufficientBalanceError(ValueError):
|
||||
user_id: str
|
||||
message: str
|
||||
|
||||
@@ -72,19 +72,58 @@ def dumps(
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
|
||||
# Sentinel value to detect when fallback is not provided
|
||||
_NO_FALLBACK = object()
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T], fallback: T | None = None, **kwargs
|
||||
) -> T:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str | bytes, *args, fallback: Any = None, **kwargs) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
|
||||
data: str | bytes,
|
||||
*args,
|
||||
target_type: Type[T] | None = None,
|
||||
fallback: Any = _NO_FALLBACK,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
parsed = orjson.loads(data)
|
||||
"""Parse JSON with optional fallback on decode errors.
|
||||
|
||||
Args:
|
||||
data: JSON string or bytes to parse
|
||||
target_type: Optional type to validate/cast result to
|
||||
fallback: Value to return on JSONDecodeError. If not provided, raises.
|
||||
**kwargs: Additional arguments (unused, for compatibility)
|
||||
|
||||
Returns:
|
||||
Parsed JSON data, or fallback value if parsing fails
|
||||
|
||||
Raises:
|
||||
orjson.JSONDecodeError: Only if fallback is not provided
|
||||
|
||||
Examples:
|
||||
>>> loads('{"valid": "json"}')
|
||||
{'valid': 'json'}
|
||||
>>> loads('invalid json', fallback=None)
|
||||
None
|
||||
>>> loads('invalid json', fallback={})
|
||||
{}
|
||||
>>> loads('invalid json') # raises orjson.JSONDecodeError
|
||||
"""
|
||||
try:
|
||||
parsed = orjson.loads(data)
|
||||
except orjson.JSONDecodeError:
|
||||
if fallback is not _NO_FALLBACK:
|
||||
return fallback
|
||||
raise
|
||||
|
||||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
BEGIN;
|
||||
|
||||
-- Drop illogical column StoreListing.agentGraphVersion;
|
||||
ALTER TABLE "StoreListing" DROP CONSTRAINT "StoreListing_agentGraphId_agentGraphVersion_fkey";
|
||||
DROP INDEX "StoreListing_agentGraphId_agentGraphVersion_idx";
|
||||
ALTER TABLE "StoreListing" DROP COLUMN "agentGraphVersion";
|
||||
|
||||
-- Add uniqueness constraint to Profile.userId and remove invalid data
|
||||
--
|
||||
-- Delete any profiles with null userId (which is invalid and doesn't occur in theory)
|
||||
DELETE FROM "Profile" WHERE "userId" IS NULL;
|
||||
--
|
||||
-- Delete duplicate profiles per userId, keeping the most recently updated one
|
||||
DELETE FROM "Profile"
|
||||
WHERE "id" IN (
|
||||
SELECT "id" FROM (
|
||||
SELECT "id", ROW_NUMBER() OVER (
|
||||
PARTITION BY "userId" ORDER BY "updatedAt" DESC, "id" DESC
|
||||
) AS rn
|
||||
FROM "Profile"
|
||||
) ranked
|
||||
WHERE rn > 1
|
||||
);
|
||||
--
|
||||
-- Add userId uniqueness constraint
|
||||
ALTER TABLE "Profile" ALTER COLUMN "userId" SET NOT NULL;
|
||||
CREATE UNIQUE INDEX "Profile_userId_key" ON "Profile"("userId");
|
||||
|
||||
-- Add formal relation StoreListing.owningUserId -> Profile.userId
|
||||
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_owner_Profile_fkey" FOREIGN KEY ("owningUserId") REFERENCES "Profile"("userId") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
COMMIT;
|
||||
@@ -0,0 +1,219 @@
|
||||
-- Update the StoreSubmission and StoreAgent views with additional fields, clearer field names, and faster joins.
|
||||
-- Steps:
|
||||
-- 1. Update `mv_agent_run_counts` to exclude runs by the agent's creator
|
||||
-- a. Drop dependent views `StoreAgent` and `Creator`
|
||||
-- b. Update `mv_agent_run_counts` and its index
|
||||
-- c. Recreate `StoreAgent` view (with updates)
|
||||
-- d. Restore `Creator` view
|
||||
-- 2. Update `StoreSubmission` view
|
||||
-- 3. Update `StoreListingReview` indices to make `StoreSubmission` query more efficient
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- Drop views that are dependent on mv_agent_run_counts
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
DROP VIEW IF EXISTS "Creator";
|
||||
|
||||
-- Update materialized view for agent run counts to exclude runs by the agent's creator
|
||||
DROP INDEX IF EXISTS "idx_mv_agent_run_counts";
|
||||
DROP MATERIALIZED VIEW IF EXISTS "mv_agent_run_counts";
|
||||
CREATE MATERIALIZED VIEW "mv_agent_run_counts" AS
|
||||
SELECT
|
||||
run."agentGraphId" AS graph_id,
|
||||
COUNT(*) AS run_count
|
||||
FROM "AgentGraphExecution" run
|
||||
JOIN "AgentGraph" graph ON graph.id = run."agentGraphId"
|
||||
-- Exclude runs by the agent's creator to avoid inflating run counts
|
||||
WHERE graph."userId" != run."userId"
|
||||
GROUP BY run."agentGraphId";
|
||||
|
||||
-- Recreate index
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_agent_run_counts" ON "mv_agent_run_counts"("graph_id");
|
||||
|
||||
-- Re-populate the materialized view
|
||||
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
|
||||
|
||||
|
||||
-- Recreate the StoreAgent view with the following changes
|
||||
-- (compared to 20260115210000_remove_storelistingversion_search):
|
||||
-- - Narrow to *explicitly active* version (sl.activeVersionId) instead of MAX(version)
|
||||
-- - Add `recommended_schedule_cron` column
|
||||
-- - Rename `"storeListingVersionId"` -> `listing_version_id`
|
||||
-- - Rename `"agentGraphVersions"` -> `graph_versions`
|
||||
-- - Rename `"agentGraphId"` -> `graph_id`
|
||||
-- - Rename `"useForOnboarding"` -> `use_for_onboarding`
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH store_agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
),
|
||||
agent_graph_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS listing_version_id,
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
slv."agentOutputDemoUrl" AS agent_output_demo,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
cp.username AS creator_username,
|
||||
cp."avatarUrl" AS creator_avatar,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(arc.run_count, 0::bigint) AS runs,
|
||||
COALESCE(reviews.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(sav.versions, ARRAY[slv.version::text]) AS versions,
|
||||
slv."agentGraphId" AS graph_id,
|
||||
COALESCE(
|
||||
agv.graph_versions,
|
||||
ARRAY[slv."agentGraphVersion"::text]
|
||||
) AS graph_versions,
|
||||
slv."isAvailable" AS is_available,
|
||||
COALESCE(sl."useForOnboarding", false) AS use_for_onboarding,
|
||||
slv."recommendedScheduleCron" AS recommended_schedule_cron
|
||||
FROM "StoreListing" AS sl
|
||||
JOIN "StoreListingVersion" AS slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv.id = sl."activeVersionId"
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" AS ag
|
||||
ON slv."agentGraphId" = ag.id
|
||||
AND slv."agentGraphVersion" = ag.version
|
||||
LEFT JOIN "Profile" AS cp
|
||||
ON sl."owningUserId" = cp."userId"
|
||||
LEFT JOIN "mv_review_stats" AS reviews
|
||||
ON sl.id = reviews."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" AS arc
|
||||
ON ag.id = arc.graph_id
|
||||
LEFT JOIN store_agent_versions AS sav
|
||||
ON sl.id = sav."storeListingId"
|
||||
LEFT JOIN agent_graph_versions AS agv
|
||||
ON sl.id = agv."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
|
||||
|
||||
-- Restore Creator view as last updated in 20250604130249_optimise_store_agent_and_creator_views,
|
||||
-- with minor changes:
|
||||
-- - Ensure top_categories always TEXT[]
|
||||
-- - Filter out empty ('') categories
|
||||
CREATE OR REPLACE VIEW "Creator" AS
|
||||
WITH creator_listings AS (
|
||||
SELECT
|
||||
sl."owningUserId",
|
||||
sl.id AS listing_id,
|
||||
slv."agentGraphId",
|
||||
slv.categories,
|
||||
sr.score,
|
||||
ar.run_count
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
LEFT JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON ar.graph_id = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
),
|
||||
creator_stats AS (
|
||||
SELECT
|
||||
cl."owningUserId",
|
||||
COUNT(DISTINCT cl.listing_id) AS num_agents,
|
||||
AVG(COALESCE(cl.score, 0)::numeric) AS agent_rating,
|
||||
SUM(COALESCE(cl.run_count, 0)) AS agent_runs,
|
||||
array_agg(DISTINCT cat ORDER BY cat)
|
||||
FILTER (WHERE cat IS NOT NULL AND cat != '') AS all_categories
|
||||
FROM creator_listings cl
|
||||
LEFT JOIN LATERAL unnest(COALESCE(cl.categories, ARRAY[]::text[])) AS cat ON true
|
||||
GROUP BY cl."owningUserId"
|
||||
)
|
||||
SELECT
|
||||
p.username,
|
||||
p.name,
|
||||
p."avatarUrl" AS avatar_url,
|
||||
p.description,
|
||||
COALESCE(cs.all_categories, ARRAY[]::text[]) AS top_categories,
|
||||
p.links,
|
||||
p."isFeatured" AS is_featured,
|
||||
COALESCE(cs.num_agents, 0::bigint) AS num_agents,
|
||||
COALESCE(cs.agent_rating, 0.0) AS agent_rating,
|
||||
COALESCE(cs.agent_runs, 0::numeric) AS agent_runs
|
||||
FROM "Profile" p
|
||||
LEFT JOIN creator_stats cs ON cs."owningUserId" = p."userId";
|
||||
|
||||
|
||||
-- Recreate the StoreSubmission view with updated fields & query strategy:
|
||||
-- - Uses mv_agent_run_counts instead of full AgentGraphExecution table scan + aggregation
|
||||
-- - Renamed agent_id, agent_version -> graph_id, graph_version
|
||||
-- - Renamed store_listing_version_id -> listing_version_id
|
||||
-- - Renamed date_submitted -> submitted_at
|
||||
-- - Renamed runs, rating -> run_count, review_avg_rating
|
||||
-- - Added fields: instructions, agent_output_demo_url, review_count, is_deleted
|
||||
DROP VIEW IF EXISTS "StoreSubmission";
|
||||
CREATE OR REPLACE VIEW "StoreSubmission" AS
|
||||
WITH review_stats AS (
|
||||
SELECT
|
||||
"storeListingVersionId" AS version_id, -- more specific than mv_review_stats
|
||||
avg(score) AS avg_rating,
|
||||
count(*) AS review_count
|
||||
FROM "StoreListingReview"
|
||||
GROUP BY "storeListingVersionId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
sl."owningUserId" AS user_id,
|
||||
sl.slug AS slug,
|
||||
|
||||
slv.id AS listing_version_id,
|
||||
slv.version AS listing_version,
|
||||
slv."agentGraphId" AS graph_id,
|
||||
slv."agentGraphVersion" AS graph_version,
|
||||
slv.name AS name,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description AS description,
|
||||
slv.instructions AS instructions,
|
||||
slv.categories AS categories,
|
||||
slv."imageUrls" AS image_urls,
|
||||
slv."videoUrl" AS video_url,
|
||||
slv."agentOutputDemoUrl" AS agent_output_demo_url,
|
||||
slv."submittedAt" AS submitted_at,
|
||||
slv."changesSummary" AS changes_summary,
|
||||
slv."submissionStatus" AS status,
|
||||
slv."reviewedAt" AS reviewed_at,
|
||||
slv."reviewerId" AS reviewer_id,
|
||||
slv."reviewComments" AS review_comments,
|
||||
slv."internalComments" AS internal_comments,
|
||||
slv."isDeleted" AS is_deleted,
|
||||
|
||||
COALESCE(run_stats.run_count, 0::bigint) AS run_count,
|
||||
COALESCE(review_stats.review_count, 0::bigint) AS review_count,
|
||||
COALESCE(review_stats.avg_rating, 0.0)::double precision AS review_avg_rating
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN review_stats ON review_stats.version_id = slv.id
|
||||
LEFT JOIN mv_agent_run_counts run_stats ON run_stats.graph_id = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false;
|
||||
|
||||
|
||||
-- Drop unused index on StoreListingReview.reviewByUserId
|
||||
DROP INDEX IF EXISTS "StoreListingReview_reviewByUserId_idx";
|
||||
-- Add index on storeListingVersionId to make StoreSubmission query faster
|
||||
CREATE INDEX "StoreListingReview_storeListingVersionId_idx" ON "StoreListingReview"("storeListingVersionId");
|
||||
|
||||
COMMIT;
|
||||
@@ -281,7 +281,6 @@ model AgentGraph {
|
||||
|
||||
Presets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
StoreListings StoreListing[]
|
||||
StoreListingVersions StoreListingVersion[]
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
@@ -814,10 +813,8 @@ model Profile {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
// Only 1 of user or group can be set.
|
||||
// The user this profile belongs to, if any.
|
||||
userId String?
|
||||
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
name String
|
||||
username String @unique
|
||||
@@ -830,6 +827,7 @@ model Profile {
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
LibraryAgents LibraryAgent[]
|
||||
StoreListings StoreListing[]
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
@@ -860,9 +858,9 @@ view Creator {
|
||||
}
|
||||
|
||||
view StoreAgent {
|
||||
listing_id String @id
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
listing_id String @id
|
||||
listing_version_id String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
@@ -879,10 +877,12 @@ view StoreAgent {
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
graph_id String
|
||||
graph_versions String[]
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
use_for_onboarding Boolean @default(false)
|
||||
|
||||
recommended_schedule_cron String?
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -896,41 +896,52 @@ view StoreAgent {
|
||||
}
|
||||
|
||||
view StoreSubmission {
|
||||
listing_id String @id
|
||||
user_id String
|
||||
slug String
|
||||
name String
|
||||
sub_heading String
|
||||
description String
|
||||
image_urls String[]
|
||||
date_submitted DateTime
|
||||
status SubmissionStatus
|
||||
runs Int
|
||||
rating Float
|
||||
agent_id String
|
||||
agent_version Int
|
||||
store_listing_version_id String
|
||||
reviewer_id String?
|
||||
review_comments String?
|
||||
internal_comments String?
|
||||
reviewed_at DateTime?
|
||||
changes_summary String?
|
||||
video_url String?
|
||||
categories String[]
|
||||
// From StoreListing:
|
||||
listing_id String
|
||||
user_id String
|
||||
slug String
|
||||
|
||||
// Index or unique are not applied to views
|
||||
// From StoreListingVersion:
|
||||
listing_version_id String @id
|
||||
listing_version Int
|
||||
graph_id String
|
||||
graph_version Int
|
||||
|
||||
name String
|
||||
sub_heading String
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
image_urls String[]
|
||||
video_url String?
|
||||
agent_output_demo_url String?
|
||||
|
||||
submitted_at DateTime?
|
||||
changes_summary String?
|
||||
status SubmissionStatus
|
||||
reviewed_at DateTime?
|
||||
reviewer_id String?
|
||||
review_comments String?
|
||||
internal_comments String?
|
||||
|
||||
is_deleted Boolean
|
||||
|
||||
// Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count Int
|
||||
review_count Int
|
||||
review_avg_rating Float
|
||||
}
|
||||
|
||||
// Note: This is actually a MATERIALIZED VIEW in the database
|
||||
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
|
||||
view mv_agent_run_counts {
|
||||
agentGraphId String @unique
|
||||
run_count Int
|
||||
graph_id String @unique
|
||||
run_count Int // excluding runs by the graph's creator
|
||||
|
||||
// Pre-aggregated count of AgentGraphExecution records by agentGraphId
|
||||
// Used by StoreAgent and Creator views for performance optimization
|
||||
// Unique index created automatically on agentGraphId for fast lookups
|
||||
// Refresh uses CONCURRENTLY to avoid blocking reads
|
||||
// Pre-aggregated count of AgentGraphExecution records by agentGraphId.
|
||||
// Used by StoreAgent, Creator, and StoreSubmission views for performance optimization.
|
||||
// - Should have a unique index on graph_id for fast lookups
|
||||
// - Refresh should use CONCURRENTLY to avoid blocking reads
|
||||
}
|
||||
|
||||
// Note: This is actually a MATERIALIZED VIEW in the database
|
||||
@@ -979,22 +990,18 @@ model StoreListing {
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentGraphId String
|
||||
agentGraphVersion Int
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
|
||||
agentGraphId String @unique
|
||||
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
CreatorProfile Profile @relation(fields: [owningUserId], references: [userId], map: "StoreListing_owner_Profile_fkey", onDelete: Cascade)
|
||||
|
||||
// Relations
|
||||
Versions StoreListingVersion[] @relation("ListingVersions")
|
||||
|
||||
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
|
||||
@@unique([agentGraphId])
|
||||
@@unique([owningUserId, slug])
|
||||
// Used in the view query
|
||||
@@index([isDeleted, hasApprovedVersion])
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
}
|
||||
|
||||
model StoreListingVersion {
|
||||
@@ -1089,16 +1096,16 @@ model UnifiedContentEmbedding {
|
||||
// Search data
|
||||
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
||||
searchableText String // Combined text for search and fallback
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
|
||||
metadata Json @default("{}") // Content-specific metadata
|
||||
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
|
||||
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
|
||||
|
||||
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
||||
@@index([contentType])
|
||||
@@index([userId])
|
||||
@@index([contentType, userId])
|
||||
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
||||
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
|
||||
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
@@ -1115,8 +1122,9 @@ model StoreListingReview {
|
||||
score Int
|
||||
comments String?
|
||||
|
||||
// Enforce one review per user per listing version
|
||||
@@unique([storeListingVersionId, reviewByUserId])
|
||||
@@index([reviewByUserId])
|
||||
@@index([storeListingVersionId])
|
||||
}
|
||||
|
||||
enum SubmissionStatus {
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Backfill quick prompts for saved business understanding records."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import click
|
||||
from prisma.models import CoPilotUnderstanding
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
update_business_understanding_prompts,
|
||||
)
|
||||
from backend.data.understanding_prompts import (
|
||||
generate_understanding_prompts,
|
||||
has_prompt_generation_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def backfill_understanding_prompts(
|
||||
batch_size: int = 100,
|
||||
limit: int | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
summary = {
|
||||
"scanned": 0,
|
||||
"candidates": 0,
|
||||
"eligible": 0,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
records = await CoPilotUnderstanding.prisma().find_many(
|
||||
order={"id": "asc"},
|
||||
skip=offset,
|
||||
take=batch_size,
|
||||
)
|
||||
if not records:
|
||||
break
|
||||
|
||||
offset += len(records)
|
||||
|
||||
for record in records:
|
||||
summary["scanned"] += 1
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
if understanding.prompts:
|
||||
summary["skipped_existing"] += 1
|
||||
continue
|
||||
|
||||
if limit is not None and summary["candidates"] >= limit:
|
||||
logger.info("Reached backfill limit of %s records", limit)
|
||||
return summary
|
||||
|
||||
summary["candidates"] += 1
|
||||
|
||||
if not has_prompt_generation_context(understanding):
|
||||
summary["skipped_no_context"] += 1
|
||||
continue
|
||||
|
||||
summary["eligible"] += 1
|
||||
if dry_run:
|
||||
continue
|
||||
|
||||
try:
|
||||
prompts = await generate_understanding_prompts(understanding)
|
||||
updated = await update_business_understanding_prompts(
|
||||
understanding.user_id, prompts
|
||||
)
|
||||
except Exception:
|
||||
summary["failed"] += 1
|
||||
logger.exception(
|
||||
"Failed to backfill prompts for user %s", understanding.user_id
|
||||
)
|
||||
continue
|
||||
|
||||
if updated is None:
|
||||
summary["failed"] += 1
|
||||
logger.warning(
|
||||
"Skipped backfill for user %s because the record no longer exists",
|
||||
understanding.user_id,
|
||||
)
|
||||
continue
|
||||
|
||||
summary["updated"] += 1
|
||||
|
||||
logger.info("Understanding prompt backfill summary: %s", json.dumps(summary))
|
||||
return summary
|
||||
|
||||
|
||||
async def run_backfill(
|
||||
batch_size: int = 100,
|
||||
limit: int | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
await db.connect()
|
||||
try:
|
||||
return await backfill_understanding_prompts(
|
||||
batch_size=batch_size,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
finally:
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Report candidates only.")
|
||||
@click.option("--limit", type=click.IntRange(min=1), default=None)
|
||||
@click.option(
|
||||
"--batch-size", type=click.IntRange(min=1), default=100, show_default=True
|
||||
)
|
||||
def main(dry_run: bool, limit: int | None, batch_size: int) -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
summary = asyncio.run(
|
||||
run_backfill(
|
||||
batch_size=batch_size,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
)
|
||||
click.echo(json.dumps(summary, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Tests for the understanding prompt backfill script."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.backfill_understanding_prompts import backfill_understanding_prompts
|
||||
|
||||
|
||||
def make_record(*, user_id: str, business: dict) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=f"understanding-{user_id}",
|
||||
userId=user_id,
|
||||
createdAt=datetime.now(timezone.utc),
|
||||
updatedAt=datetime.now(timezone.utc),
|
||||
data={"business": business},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_dry_run():
|
||||
record = make_record(
|
||||
user_id="user-1",
|
||||
business={"business_name": "Acme", "industry": "Support"},
|
||||
)
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
),
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10, dry_run=True)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 1,
|
||||
"eligible": 1,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
mock_generate.assert_not_awaited()
|
||||
mock_update.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_skips_existing_prompts():
|
||||
record = make_record(
|
||||
user_id="user-1",
|
||||
business={
|
||||
"business_name": "Acme",
|
||||
"prompts": ["Prompt one", "Prompt two", "Prompt three"],
|
||||
},
|
||||
)
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 0,
|
||||
"eligible": 0,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 1,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_updates_missing_prompts():
|
||||
record = make_record(
|
||||
user_id="user-1",
|
||||
business={"business_name": "Acme", "industry": "Support"},
|
||||
)
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
),
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
return_value=["Prompt one", "Prompt two", "Prompt three"],
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
return_value=object(),
|
||||
) as mock_update,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 1,
|
||||
"eligible": 1,
|
||||
"updated": 1,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
mock_generate.assert_awaited_once()
|
||||
mock_update.assert_awaited_once_with(
|
||||
"user-1", ["Prompt one", "Prompt two", "Prompt three"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_skips_records_without_context():
|
||||
record = make_record(user_id="user-1", business={})
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
),
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 1,
|
||||
"eligible": 0,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 1,
|
||||
}
|
||||
mock_generate.assert_not_awaited()
|
||||
@@ -23,14 +23,14 @@
|
||||
"1.0.0",
|
||||
"1.1.0"
|
||||
],
|
||||
"agentGraphVersions": [
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_versions": [
|
||||
"1",
|
||||
"2"
|
||||
],
|
||||
"agentGraphId": "test-graph-id",
|
||||
"last_updated": "2023-01-01T00:00:00",
|
||||
"recommended_schedule_cron": null,
|
||||
"active_version_id": null,
|
||||
"has_approved_version": false,
|
||||
"active_version_id": "test-version-id",
|
||||
"has_approved_version": true,
|
||||
"changelog": null
|
||||
}
|
||||
@@ -1,14 +1,16 @@
|
||||
{
|
||||
"name": "Test User",
|
||||
"username": "creator1",
|
||||
"name": "Test User",
|
||||
"description": "Test creator description",
|
||||
"avatar_url": "avatar.jpg",
|
||||
"links": [
|
||||
"link1.com",
|
||||
"link2.com"
|
||||
],
|
||||
"avatar_url": "avatar.jpg",
|
||||
"agent_rating": 4.8,
|
||||
"is_featured": true,
|
||||
"num_agents": 5,
|
||||
"agent_runs": 1000,
|
||||
"agent_rating": 4.8,
|
||||
"top_categories": [
|
||||
"category1",
|
||||
"category2"
|
||||
|
||||
@@ -1,54 +1,94 @@
|
||||
{
|
||||
"creators": [
|
||||
{
|
||||
"name": "Creator 0",
|
||||
"username": "creator0",
|
||||
"name": "Creator 0",
|
||||
"description": "Creator 0 description",
|
||||
"avatar_url": "avatar0.jpg",
|
||||
"links": [
|
||||
"user0.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 1",
|
||||
"username": "creator1",
|
||||
"name": "Creator 1",
|
||||
"description": "Creator 1 description",
|
||||
"avatar_url": "avatar1.jpg",
|
||||
"links": [
|
||||
"user1.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 2",
|
||||
"username": "creator2",
|
||||
"name": "Creator 2",
|
||||
"description": "Creator 2 description",
|
||||
"avatar_url": "avatar2.jpg",
|
||||
"links": [
|
||||
"user2.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 3",
|
||||
"username": "creator3",
|
||||
"name": "Creator 3",
|
||||
"description": "Creator 3 description",
|
||||
"avatar_url": "avatar3.jpg",
|
||||
"links": [
|
||||
"user3.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 4",
|
||||
"username": "creator4",
|
||||
"name": "Creator 4",
|
||||
"description": "Creator 4 description",
|
||||
"avatar_url": "avatar4.jpg",
|
||||
"links": [
|
||||
"user4.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -2,32 +2,33 @@
|
||||
"submissions": [
|
||||
{
|
||||
"listing_id": "test-listing-id",
|
||||
"agent_id": "test-agent-id",
|
||||
"agent_version": 1,
|
||||
"user_id": "test-user-id",
|
||||
"slug": "test-agent",
|
||||
"listing_version_id": "test-version-id",
|
||||
"listing_version": 1,
|
||||
"graph_id": "test-agent-id",
|
||||
"graph_version": 1,
|
||||
"name": "Test Agent",
|
||||
"sub_heading": "Test agent subheading",
|
||||
"slug": "test-agent",
|
||||
"description": "Test agent description",
|
||||
"instructions": null,
|
||||
"instructions": "Click the button!",
|
||||
"categories": [
|
||||
"test-category"
|
||||
],
|
||||
"image_urls": [
|
||||
"test.jpg"
|
||||
],
|
||||
"date_submitted": "2023-01-01T00:00:00",
|
||||
"video_url": "test.mp4",
|
||||
"agent_output_demo_url": "demo_video.mp4",
|
||||
"submitted_at": "2023-01-01T00:00:00",
|
||||
"changes_summary": "Initial Submission",
|
||||
"status": "APPROVED",
|
||||
"runs": 50,
|
||||
"rating": 4.2,
|
||||
"store_listing_version_id": null,
|
||||
"version": null,
|
||||
"reviewed_at": null,
|
||||
"reviewer_id": null,
|
||||
"review_comments": null,
|
||||
"internal_comments": null,
|
||||
"reviewed_at": null,
|
||||
"changes_summary": null,
|
||||
"video_url": "test.mp4",
|
||||
"agent_output_demo_url": null,
|
||||
"categories": [
|
||||
"test-category"
|
||||
]
|
||||
"run_count": 50,
|
||||
"review_count": 5,
|
||||
"review_avg_rating": 4.2
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -128,7 +128,7 @@ class TestDataCreator:
|
||||
email = "test123@gmail.com"
|
||||
else:
|
||||
email = faker.unique.email()
|
||||
password = "testpassword123" # Standard test password
|
||||
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
|
||||
user_id = f"test-user-{i}-{faker.uuid4()}"
|
||||
|
||||
# Create user in Supabase Auth (if needed)
|
||||
@@ -571,8 +571,8 @@ class TestDataCreator:
|
||||
if test_user and self.agent_graphs:
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"agent_id": self.agent_graphs[0]["id"],
|
||||
"agent_version": 1,
|
||||
"graph_id": self.agent_graphs[0]["id"],
|
||||
"graph_version": 1,
|
||||
"slug": "test-agent-submission",
|
||||
"name": "Test Agent Submission",
|
||||
"sub_heading": "A test agent for frontend testing",
|
||||
@@ -593,9 +593,9 @@ class TestDataCreator:
|
||||
print("✅ Created special test store submission for test123@gmail.com")
|
||||
|
||||
# ALWAYS approve and feature the test submission
|
||||
if test_submission.store_listing_version_id:
|
||||
if test_submission.listing_version_id:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.store_listing_version_id,
|
||||
store_listing_version_id=test_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Test submission approved",
|
||||
internal_comments="Auto-approved test submission",
|
||||
@@ -605,7 +605,7 @@ class TestDataCreator:
|
||||
print("✅ Approved test store submission")
|
||||
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.store_listing_version_id},
|
||||
where={"id": test_submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
@@ -640,8 +640,8 @@ class TestDataCreator:
|
||||
|
||||
submission = await create_store_submission(
|
||||
user_id=user["id"],
|
||||
agent_id=graph["id"],
|
||||
agent_version=graph.get("version", 1),
|
||||
graph_id=graph["id"],
|
||||
graph_version=graph.get("version", 1),
|
||||
slug=faker.slug(),
|
||||
name=graph.get("name", faker.sentence(nb_words=3)),
|
||||
sub_heading=faker.sentence(),
|
||||
@@ -654,7 +654,7 @@ class TestDataCreator:
|
||||
submissions.append(submission.model_dump())
|
||||
print(f"✅ Created store submission: {submission.name}")
|
||||
|
||||
if submission.store_listing_version_id:
|
||||
if submission.listing_version_id:
|
||||
# DETERMINISTIC: First N submissions are always approved
|
||||
# First GUARANTEED_FEATURED_AGENTS of those are always featured
|
||||
should_approve = (
|
||||
@@ -667,7 +667,7 @@ class TestDataCreator:
|
||||
try:
|
||||
reviewer_id = random.choice(self.users)["id"]
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=submission.store_listing_version_id,
|
||||
store_listing_version_id=submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Auto-approved for E2E testing",
|
||||
internal_comments="Automatically approved by E2E test data script",
|
||||
@@ -683,9 +683,7 @@ class TestDataCreator:
|
||||
if should_feature:
|
||||
try:
|
||||
await prisma.storelistingversion.update(
|
||||
where={
|
||||
"id": submission.store_listing_version_id
|
||||
},
|
||||
where={"id": submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
@@ -699,9 +697,7 @@ class TestDataCreator:
|
||||
elif random.random() < 0.2:
|
||||
try:
|
||||
await prisma.storelistingversion.update(
|
||||
where={
|
||||
"id": submission.store_listing_version_id
|
||||
},
|
||||
where={"id": submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
@@ -721,7 +717,7 @@ class TestDataCreator:
|
||||
try:
|
||||
reviewer_id = random.choice(self.users)["id"]
|
||||
await review_store_submission(
|
||||
store_listing_version_id=submission.store_listing_version_id,
|
||||
store_listing_version_id=submission.listing_version_id,
|
||||
is_approved=False,
|
||||
external_comments="Submission rejected - needs improvements",
|
||||
internal_comments="Automatically rejected by E2E test data script",
|
||||
|
||||
@@ -394,7 +394,6 @@ async def main():
|
||||
listing = await db.storelisting.create(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"owningUserId": user.id,
|
||||
"hasApprovedVersion": random.choice([True, False]),
|
||||
"slug": slug,
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"cssVariables": false,
|
||||
"prefix": ""
|
||||
},
|
||||
"iconLibrary": "radix",
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils"
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
StoreListingsWithVersionsResponse,
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
getV2GetAdminListingsHistory,
|
||||
postV2ReviewStoreSubmission,
|
||||
getV2AdminDownloadAgentFile,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
|
||||
export async function approveAgent(formData: FormData) {
|
||||
const data = {
|
||||
store_listing_version_id: formData.get("id") as string,
|
||||
const storeListingVersionId = formData.get("id") as string;
|
||||
const comments = formData.get("comments") as string;
|
||||
|
||||
await postV2ReviewStoreSubmission(storeListingVersionId, {
|
||||
store_listing_version_id: storeListingVersionId,
|
||||
is_approved: true,
|
||||
comments: formData.get("comments") as string,
|
||||
};
|
||||
const api = new BackendApi();
|
||||
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
|
||||
comments,
|
||||
});
|
||||
|
||||
revalidatePath("/admin/marketplace");
|
||||
}
|
||||
|
||||
export async function rejectAgent(formData: FormData) {
|
||||
const data = {
|
||||
store_listing_version_id: formData.get("id") as string,
|
||||
const storeListingVersionId = formData.get("id") as string;
|
||||
const comments = formData.get("comments") as string;
|
||||
const internal_comments =
|
||||
(formData.get("internal_comments") as string) || undefined;
|
||||
|
||||
await postV2ReviewStoreSubmission(storeListingVersionId, {
|
||||
store_listing_version_id: storeListingVersionId,
|
||||
is_approved: false,
|
||||
comments: formData.get("comments") as string,
|
||||
internal_comments: formData.get("internal_comments") as string,
|
||||
};
|
||||
const api = new BackendApi();
|
||||
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
|
||||
comments,
|
||||
internal_comments,
|
||||
});
|
||||
|
||||
revalidatePath("/admin/marketplace");
|
||||
}
|
||||
@@ -37,26 +43,18 @@ export async function getAdminListingsWithVersions(
|
||||
search?: string,
|
||||
page: number = 1,
|
||||
pageSize: number = 20,
|
||||
): Promise<StoreListingsWithVersionsResponse> {
|
||||
const data: Record<string, any> = {
|
||||
) {
|
||||
const response = await getV2GetAdminListingsHistory({
|
||||
status,
|
||||
search,
|
||||
page,
|
||||
page_size: pageSize,
|
||||
};
|
||||
});
|
||||
|
||||
if (status) {
|
||||
data.status = status;
|
||||
}
|
||||
|
||||
if (search) {
|
||||
data.search = search;
|
||||
}
|
||||
const api = new BackendApi();
|
||||
const response = await api.getAdminListingsWithVersions(data);
|
||||
return response;
|
||||
return okData(response);
|
||||
}
|
||||
|
||||
export async function downloadAsAdmin(storeListingVersion: string) {
|
||||
const api = new BackendApi();
|
||||
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
|
||||
return file;
|
||||
const response = await getV2AdminDownloadAgentFile(storeListingVersion);
|
||||
return okData(response);
|
||||
}
|
||||
|
||||
@@ -6,10 +6,8 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import {
|
||||
StoreSubmission,
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
|
||||
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { PaginationControls } from "../../../../../components/__legacy__/ui/pagination-controls";
|
||||
import { getAdminListingsWithVersions } from "@/app/(platform)/admin/marketplace/actions";
|
||||
import { ExpandableRow } from "./ExpandleRow";
|
||||
@@ -17,12 +15,14 @@ import { SearchAndFilterAdminMarketplace } from "./SearchFilterForm";
|
||||
|
||||
// Helper function to get the latest version by version number
|
||||
const getLatestVersionByNumber = (
|
||||
versions: StoreSubmission[],
|
||||
): StoreSubmission | null => {
|
||||
versions: StoreSubmissionAdminView[] | undefined,
|
||||
): StoreSubmissionAdminView | null => {
|
||||
if (!versions || versions.length === 0) return null;
|
||||
return versions.reduce(
|
||||
(latest, current) =>
|
||||
(current.version ?? 0) > (latest.version ?? 1) ? current : latest,
|
||||
(current.listing_version ?? 0) > (latest.listing_version ?? 1)
|
||||
? current
|
||||
: latest,
|
||||
versions[0],
|
||||
);
|
||||
};
|
||||
@@ -37,12 +37,14 @@ export async function AdminAgentsDataTable({
|
||||
initialSearch?: string;
|
||||
}) {
|
||||
// Server-side data fetching
|
||||
const { listings, pagination } = await getAdminListingsWithVersions(
|
||||
const data = await getAdminListingsWithVersions(
|
||||
initialStatus,
|
||||
initialSearch,
|
||||
initialPage,
|
||||
10,
|
||||
);
|
||||
const listings = data?.listings ?? [];
|
||||
const pagination = data?.pagination;
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
@@ -92,7 +94,7 @@ export async function AdminAgentsDataTable({
|
||||
|
||||
<PaginationControls
|
||||
currentPage={initialPage}
|
||||
totalPages={pagination.total_pages}
|
||||
totalPages={pagination?.total_pages ?? 1}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -13,7 +13,7 @@ import {
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { Textarea } from "@/components/__legacy__/ui/textarea";
|
||||
import type { StoreSubmission } from "@/lib/autogpt-server-api/types";
|
||||
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
|
||||
import { useRouter } from "next/navigation";
|
||||
import {
|
||||
approveAgent,
|
||||
@@ -23,7 +23,7 @@ import {
|
||||
export function ApproveRejectButtons({
|
||||
version,
|
||||
}: {
|
||||
version: StoreSubmission;
|
||||
version: StoreSubmissionAdminView;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [isApproveDialogOpen, setIsApproveDialogOpen] = useState(false);
|
||||
@@ -95,7 +95,7 @@ export function ApproveRejectButtons({
|
||||
<input
|
||||
type="hidden"
|
||||
name="id"
|
||||
value={version.store_listing_version_id || ""}
|
||||
value={version.listing_version_id || ""}
|
||||
/>
|
||||
|
||||
<div className="grid gap-4 py-4">
|
||||
@@ -142,7 +142,7 @@ export function ApproveRejectButtons({
|
||||
<input
|
||||
type="hidden"
|
||||
name="id"
|
||||
value={version.store_listing_version_id || ""}
|
||||
value={version.listing_version_id || ""}
|
||||
/>
|
||||
|
||||
<div className="grid gap-4 py-4">
|
||||
|
||||
@@ -12,11 +12,9 @@ import {
|
||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
||||
import { ChevronDown, ChevronRight } from "lucide-react";
|
||||
import { formatDistanceToNow } from "date-fns";
|
||||
import {
|
||||
type StoreListingWithVersions,
|
||||
type StoreSubmission,
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import type { StoreListingWithVersionsAdminView } from "@/app/api/__generated__/models/storeListingWithVersionsAdminView";
|
||||
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { ApproveRejectButtons } from "./ApproveRejectButton";
|
||||
import { DownloadAgentAdminButton } from "./DownloadAgentButton";
|
||||
|
||||
@@ -38,8 +36,8 @@ export function ExpandableRow({
|
||||
listing,
|
||||
latestVersion,
|
||||
}: {
|
||||
listing: StoreListingWithVersions;
|
||||
latestVersion: StoreSubmission | null;
|
||||
listing: StoreListingWithVersionsAdminView;
|
||||
latestVersion: StoreSubmissionAdminView | null;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
|
||||
@@ -69,17 +67,17 @@ export function ExpandableRow({
|
||||
{latestVersion?.status && getStatusBadge(latestVersion.status)}
|
||||
</TableCell>
|
||||
<TableCell onClick={() => setExpanded(!expanded)}>
|
||||
{latestVersion?.date_submitted
|
||||
? formatDistanceToNow(new Date(latestVersion.date_submitted), {
|
||||
{latestVersion?.submitted_at
|
||||
? formatDistanceToNow(new Date(latestVersion.submitted_at), {
|
||||
addSuffix: true,
|
||||
})
|
||||
: "Unknown"}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-2">
|
||||
{latestVersion?.store_listing_version_id && (
|
||||
{latestVersion?.listing_version_id && (
|
||||
<DownloadAgentAdminButton
|
||||
storeListingVersionId={latestVersion.store_listing_version_id}
|
||||
storeListingVersionId={latestVersion.listing_version_id}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -115,14 +113,17 @@ export function ExpandableRow({
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{listing.versions
|
||||
.sort((a, b) => (b.version ?? 1) - (a.version ?? 0))
|
||||
{(listing.versions ?? [])
|
||||
.sort(
|
||||
(a, b) =>
|
||||
(b.listing_version ?? 1) - (a.listing_version ?? 0),
|
||||
)
|
||||
.map((version) => (
|
||||
<TableRow key={version.store_listing_version_id}>
|
||||
<TableRow key={version.listing_version_id}>
|
||||
<TableCell>
|
||||
v{version.version || "?"}
|
||||
{version.store_listing_version_id ===
|
||||
listing.active_version_id && (
|
||||
v{version.listing_version || "?"}
|
||||
{version.listing_version_id ===
|
||||
listing.active_listing_version_id && (
|
||||
<Badge className="ml-2 bg-blue-500">Active</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
@@ -131,9 +132,9 @@ export function ExpandableRow({
|
||||
{version.changes_summary || "No summary"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{version.date_submitted
|
||||
{version.submitted_at
|
||||
? formatDistanceToNow(
|
||||
new Date(version.date_submitted),
|
||||
new Date(version.submitted_at),
|
||||
{ addSuffix: true },
|
||||
)
|
||||
: "Unknown"}
|
||||
@@ -182,10 +183,10 @@ export function ExpandableRow({
|
||||
{/* <TableCell>{version.categories.join(", ")}</TableCell> */}
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-2">
|
||||
{version.store_listing_version_id && (
|
||||
{version.listing_version_id && (
|
||||
<DownloadAgentAdminButton
|
||||
storeListingVersionId={
|
||||
version.store_listing_version_id
|
||||
version.listing_version_id
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
import { SubmissionStatus } from "@/lib/autogpt-server-api/types";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
|
||||
export function SearchAndFilterAdminMarketplace({
|
||||
initialSearch,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { Suspense } from "react";
|
||||
import type { SubmissionStatus } from "@/lib/autogpt-server-api/types";
|
||||
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { AdminAgentsDataTable } from "./components/AdminAgentsDataTable";
|
||||
|
||||
type MarketplaceAdminPageSearchParams = {
|
||||
page?: string;
|
||||
status?: string;
|
||||
status?: SubmissionStatus;
|
||||
search?: string;
|
||||
};
|
||||
|
||||
@@ -15,7 +15,7 @@ async function AdminMarketplaceDashboard({
|
||||
searchParams: MarketplaceAdminPageSearchParams;
|
||||
}) {
|
||||
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
|
||||
const status = searchParams.status as SubmissionStatus | undefined;
|
||||
const status = searchParams.status;
|
||||
const search = searchParams.search;
|
||||
|
||||
return (
|
||||
|
||||
@@ -151,6 +151,9 @@ export function ChatInput({
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
</PromptInputTools>
|
||||
|
||||
<div className="flex items-center gap-4">
|
||||
{showMicButton && (
|
||||
<RecordingButton
|
||||
isRecording={isRecording}
|
||||
@@ -160,13 +163,12 @@ export function ChatInput({
|
||||
onClick={toggleRecording}
|
||||
/>
|
||||
)}
|
||||
</PromptInputTools>
|
||||
|
||||
{isStreaming ? (
|
||||
<PromptInputSubmit status="streaming" onStop={onStop} />
|
||||
) : (
|
||||
<PromptInputSubmit disabled={!canSend} />
|
||||
)}
|
||||
{isStreaming ? (
|
||||
<PromptInputSubmit status="streaming" onStop={onStop} />
|
||||
) : (
|
||||
<PromptInputSubmit disabled={!canSend} />
|
||||
)}
|
||||
</div>
|
||||
</PromptInputFooter>
|
||||
</InputGroup>
|
||||
</form>
|
||||
|
||||
@@ -28,10 +28,9 @@ export function RecordingButton({
|
||||
disabled={disabled}
|
||||
onClick={onClick}
|
||||
className={cn(
|
||||
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
"border-0 bg-white text-zinc-500 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
disabled && "opacity-40",
|
||||
isRecording &&
|
||||
"animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600",
|
||||
isRecording && "animate-pulse bg-red-500 text-white hover:bg-red-600",
|
||||
isTranscribing && "bg-zinc-100 text-zinc-400",
|
||||
isStreaming && "opacity-40",
|
||||
)}
|
||||
|
||||
@@ -5,15 +5,19 @@ import {
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import { Message, MessageContent } from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { FileUIPart, ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { parseSpecialMarkers } from "./helpers";
|
||||
import { AssistantMessageActions } from "./components/AssistantMessageActions";
|
||||
import { CollapsedToolGroup } from "./components/CollapsedToolGroup";
|
||||
import { MessageAttachments } from "./components/MessageAttachments";
|
||||
import { MessagePartRenderer } from "./components/MessagePartRenderer";
|
||||
import { ReasoningCollapse } from "./components/ReasoningCollapse";
|
||||
import { ThinkingIndicator } from "./components/ThinkingIndicator";
|
||||
|
||||
type MessagePart = UIMessage<unknown, UIDataTypes, UITools>["parts"][number];
|
||||
|
||||
interface Props {
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
status: string;
|
||||
@@ -23,6 +27,132 @@ interface Props {
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
function isCompletedToolPart(part: MessagePart): part is ToolUIPart {
|
||||
return (
|
||||
part.type.startsWith("tool-") &&
|
||||
"state" in part &&
|
||||
(part.state === "output-available" || part.state === "output-error")
|
||||
);
|
||||
}
|
||||
|
||||
type RenderSegment =
|
||||
| { kind: "part"; part: MessagePart; index: number }
|
||||
| { kind: "collapsed-group"; parts: ToolUIPart[] };
|
||||
|
||||
// Tool types that have custom renderers and should NOT be collapsed
|
||||
const CUSTOM_TOOL_TYPES = new Set([
|
||||
"tool-find_block",
|
||||
"tool-find_agent",
|
||||
"tool-find_library_agent",
|
||||
"tool-search_docs",
|
||||
"tool-get_doc_page",
|
||||
"tool-run_block",
|
||||
"tool-run_mcp_tool",
|
||||
"tool-run_agent",
|
||||
"tool-schedule_agent",
|
||||
"tool-create_agent",
|
||||
"tool-edit_agent",
|
||||
"tool-view_agent_output",
|
||||
"tool-search_feature_requests",
|
||||
"tool-create_feature_request",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Groups consecutive completed generic tool parts into collapsed segments.
|
||||
* Non-generic tools (those with custom renderers) and active/streaming tools
|
||||
* are left as individual parts.
|
||||
*/
|
||||
function buildRenderSegments(
|
||||
parts: MessagePart[],
|
||||
baseIndex = 0,
|
||||
): RenderSegment[] {
|
||||
const segments: RenderSegment[] = [];
|
||||
let pendingGroup: Array<{ part: ToolUIPart; index: number }> | null = null;
|
||||
|
||||
function flushGroup() {
|
||||
if (!pendingGroup) return;
|
||||
if (pendingGroup.length >= 2) {
|
||||
segments.push({
|
||||
kind: "collapsed-group",
|
||||
parts: pendingGroup.map((p) => p.part),
|
||||
});
|
||||
} else {
|
||||
for (const p of pendingGroup) {
|
||||
segments.push({ kind: "part", part: p.part, index: p.index });
|
||||
}
|
||||
}
|
||||
pendingGroup = null;
|
||||
}
|
||||
|
||||
parts.forEach((part, i) => {
|
||||
const absoluteIndex = baseIndex + i;
|
||||
const isGenericCompletedTool =
|
||||
isCompletedToolPart(part) && !CUSTOM_TOOL_TYPES.has(part.type);
|
||||
|
||||
if (isGenericCompletedTool) {
|
||||
if (!pendingGroup) pendingGroup = [];
|
||||
pendingGroup.push({ part: part as ToolUIPart, index: absoluteIndex });
|
||||
} else {
|
||||
flushGroup();
|
||||
segments.push({ kind: "part", part, index: absoluteIndex });
|
||||
}
|
||||
});
|
||||
|
||||
flushGroup();
|
||||
return segments;
|
||||
}
|
||||
|
||||
/**
|
||||
* For finalized assistant messages, split parts into "reasoning" (intermediate
|
||||
* text + tools before the final response) and "response" (final text after the
|
||||
* last tool). If there are no tools, everything is response.
|
||||
*/
|
||||
function splitReasoningAndResponse(parts: MessagePart[]): {
|
||||
reasoning: MessagePart[];
|
||||
response: MessagePart[];
|
||||
} {
|
||||
const lastToolIndex = parts.findLastIndex((p) => p.type.startsWith("tool-"));
|
||||
|
||||
// No tools → everything is response
|
||||
if (lastToolIndex === -1) {
|
||||
return { reasoning: [], response: parts };
|
||||
}
|
||||
|
||||
// Check if there's any text after the last tool
|
||||
const hasResponseAfterTools = parts
|
||||
.slice(lastToolIndex + 1)
|
||||
.some((p) => p.type === "text");
|
||||
|
||||
if (!hasResponseAfterTools) {
|
||||
// No final text response → don't collapse anything
|
||||
return { reasoning: [], response: parts };
|
||||
}
|
||||
|
||||
return {
|
||||
reasoning: parts.slice(0, lastToolIndex + 1),
|
||||
response: parts.slice(lastToolIndex + 1),
|
||||
};
|
||||
}
|
||||
|
||||
function renderSegments(
|
||||
segments: RenderSegment[],
|
||||
messageID: string,
|
||||
): React.ReactNode[] {
|
||||
return segments.map((seg, segIdx) => {
|
||||
if (seg.kind === "collapsed-group") {
|
||||
return <CollapsedToolGroup key={`group-${segIdx}`} parts={seg.parts} />;
|
||||
}
|
||||
return (
|
||||
<MessagePartRenderer
|
||||
key={`${messageID}-${seg.index}`}
|
||||
part={seg.part}
|
||||
messageID={messageID}
|
||||
partIndex={seg.index}
|
||||
/>
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
/** Collect all messages belonging to a turn: the user message + every
|
||||
* assistant message up to (but not including) the next user message. */
|
||||
function getTurnMessages(
|
||||
@@ -119,6 +249,24 @@ export function ChatMessagesContainer({
|
||||
(p): p is FileUIPart => p.type === "file",
|
||||
);
|
||||
|
||||
// For finalized assistant messages, split into reasoning + response.
|
||||
// During streaming, show everything normally with tool collapsing.
|
||||
const isFinalized =
|
||||
message.role === "assistant" && !isCurrentlyStreaming;
|
||||
const { reasoning, response } = isFinalized
|
||||
? splitReasoningAndResponse(message.parts)
|
||||
: { reasoning: [] as MessagePart[], response: message.parts };
|
||||
const hasReasoning = reasoning.length > 0;
|
||||
|
||||
const responseStartIndex = message.parts.length - response.length;
|
||||
const responseSegments =
|
||||
message.role === "assistant"
|
||||
? buildRenderSegments(response, responseStartIndex)
|
||||
: null;
|
||||
const reasoningSegments = hasReasoning
|
||||
? buildRenderSegments(reasoning, 0)
|
||||
: null;
|
||||
|
||||
return (
|
||||
<Message from={message.role} key={message.id}>
|
||||
<MessageContent
|
||||
@@ -128,14 +276,21 @@ export function ChatMessagesContainer({
|
||||
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
|
||||
}
|
||||
>
|
||||
{message.parts.map((part, i) => (
|
||||
<MessagePartRenderer
|
||||
key={`${message.id}-${i}`}
|
||||
part={part}
|
||||
messageID={message.id}
|
||||
partIndex={i}
|
||||
/>
|
||||
))}
|
||||
{hasReasoning && reasoningSegments && (
|
||||
<ReasoningCollapse>
|
||||
{renderSegments(reasoningSegments, message.id)}
|
||||
</ReasoningCollapse>
|
||||
)}
|
||||
{responseSegments
|
||||
? renderSegments(responseSegments, message.id)
|
||||
: message.parts.map((part, i) => (
|
||||
<MessagePartRenderer
|
||||
key={`${message.id}-${i}`}
|
||||
part={part}
|
||||
messageID={message.id}
|
||||
partIndex={i}
|
||||
/>
|
||||
))}
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
<TurnStatsBar
|
||||
turnMessages={getTurnMessages(messages, messageIndex)}
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"use client";
|
||||
|
||||
import { useId, useState } from "react";
|
||||
import {
|
||||
ArrowsClockwiseIcon,
|
||||
CaretRightIcon,
|
||||
CheckCircleIcon,
|
||||
FileIcon,
|
||||
FilesIcon,
|
||||
GearIcon,
|
||||
GlobeIcon,
|
||||
ListChecksIcon,
|
||||
MagnifyingGlassIcon,
|
||||
MonitorIcon,
|
||||
PencilSimpleIcon,
|
||||
TerminalIcon,
|
||||
TrashIcon,
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import {
|
||||
type ToolCategory,
|
||||
extractToolName,
|
||||
getAnimationText,
|
||||
getToolCategory,
|
||||
} from "../../../tools/GenericTool/helpers";
|
||||
|
||||
interface Props {
|
||||
parts: ToolUIPart[];
|
||||
}
|
||||
|
||||
/** Category icon matching GenericTool's ToolIcon for completed states. */
|
||||
function EntryIcon({
|
||||
category,
|
||||
isError,
|
||||
}: {
|
||||
category: ToolCategory;
|
||||
isError: boolean;
|
||||
}) {
|
||||
if (isError) {
|
||||
return (
|
||||
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
|
||||
);
|
||||
}
|
||||
|
||||
const iconClass = "text-green-500";
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "web":
|
||||
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "browser":
|
||||
return <MonitorIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-delete":
|
||||
return <TrashIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-list":
|
||||
return <FilesIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "search":
|
||||
return (
|
||||
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "edit":
|
||||
return (
|
||||
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "todo":
|
||||
return (
|
||||
<ListChecksIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "compaction":
|
||||
return (
|
||||
<ArrowsClockwiseIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
default:
|
||||
return <GearIcon size={14} weight="regular" className={iconClass} />;
|
||||
}
|
||||
}
|
||||
|
||||
export function CollapsedToolGroup({ parts }: Props) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const panelId = useId();
|
||||
|
||||
const errorCount = parts.filter((p) => p.state === "output-error").length;
|
||||
const label =
|
||||
errorCount > 0
|
||||
? `${parts.length} tool calls (${errorCount} failed)`
|
||||
: `${parts.length} tool calls completed`;
|
||||
|
||||
return (
|
||||
<div className="py-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
aria-expanded={expanded}
|
||||
aria-controls={panelId}
|
||||
className="flex items-center gap-1.5 text-sm text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<CaretRightIcon
|
||||
size={12}
|
||||
weight="bold"
|
||||
className={
|
||||
"transition-transform duration-150 " + (expanded ? "rotate-90" : "")
|
||||
}
|
||||
/>
|
||||
{errorCount > 0 ? (
|
||||
<WarningDiamondIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className="text-red-500"
|
||||
/>
|
||||
) : (
|
||||
<CheckCircleIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className="text-green-500"
|
||||
/>
|
||||
)}
|
||||
<span>{label}</span>
|
||||
</button>
|
||||
|
||||
{expanded && (
|
||||
<div
|
||||
id={panelId}
|
||||
className="ml-5 mt-1 space-y-0.5 border-l border-neutral-200 pl-3"
|
||||
>
|
||||
{parts.map((part) => {
|
||||
const toolName = extractToolName(part);
|
||||
const category = getToolCategory(toolName);
|
||||
const text = getAnimationText(part, category);
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
return (
|
||||
<div
|
||||
key={part.toolCallId}
|
||||
className={
|
||||
"flex items-center gap-1.5 text-xs " +
|
||||
(isError ? "text-red-500" : "text-muted-foreground")
|
||||
}
|
||||
>
|
||||
<EntryIcon category={category} isError={isError} />
|
||||
<span>{text}</span>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -15,6 +15,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
const [comment, setComment] = useState("");
|
||||
|
||||
function handleSubmit() {
|
||||
if (!comment.trim()) return;
|
||||
onSubmit(comment);
|
||||
setComment("");
|
||||
}
|
||||
@@ -36,7 +37,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="mx-auto w-[95%] space-y-4">
|
||||
<p className="text-sm text-slate-600">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Your feedback helps us improve. Share details below.
|
||||
</p>
|
||||
<Textarea
|
||||
@@ -48,12 +49,18 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
className="resize-none"
|
||||
/>
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-xs text-slate-400">{comment.length}/2000</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{comment.length}/2000
|
||||
</p>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button size="sm" onClick={handleSubmit}>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleSubmit}
|
||||
disabled={!comment.trim()}
|
||||
>
|
||||
Submit feedback
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
SearchFeatureRequestsTool,
|
||||
} from "../../../tools/FeatureRequests/FeatureRequests";
|
||||
import { FindAgentsTool } from "../../../tools/FindAgents/FindAgents";
|
||||
import { FolderTool } from "../../../tools/FolderTool/FolderTool";
|
||||
import { FindBlocksTool } from "../../../tools/FindBlocks/FindBlocks";
|
||||
import { GenericTool } from "../../../tools/GenericTool/GenericTool";
|
||||
import { RunAgentTool } from "../../../tools/RunAgent/RunAgent";
|
||||
@@ -145,6 +146,13 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
return <SearchFeatureRequestsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-create_feature_request":
|
||||
return <CreateFeatureRequestTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-create_folder":
|
||||
case "tool-list_folders":
|
||||
case "tool-update_folder":
|
||||
case "tool-move_folder":
|
||||
case "tool-delete_folder":
|
||||
case "tool-move_agents_to_folder":
|
||||
return <FolderTool key={key} part={part as ToolUIPart} />;
|
||||
default:
|
||||
// Render a generic tool indicator for SDK built-in
|
||||
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { LightbulbIcon } from "@phosphor-icons/react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
|
||||
interface Props {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function ReasoningCollapse({ children }: Props) {
|
||||
return (
|
||||
<Dialog title="Reasoning">
|
||||
<Dialog.Trigger>
|
||||
<button
|
||||
type="button"
|
||||
className="flex items-center gap-1 text-xs text-zinc-500 transition-colors hover:text-zinc-700"
|
||||
>
|
||||
<LightbulbIcon size={12} weight="bold" />
|
||||
<span>Show reasoning</span>
|
||||
</button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-1">{children}</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -7,11 +7,8 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { SpinnerGapIcon } from "@phosphor-icons/react";
|
||||
import { motion } from "framer-motion";
|
||||
import { useEffect, useState } from "react";
|
||||
import {
|
||||
getGreetingName,
|
||||
getInputPlaceholder,
|
||||
getQuickActions,
|
||||
} from "./helpers";
|
||||
import { getGreetingName, getInputPlaceholder } from "./helpers";
|
||||
import { useQuickActions } from "./useQuickActions";
|
||||
|
||||
interface Props {
|
||||
inputLayoutId: string;
|
||||
@@ -33,7 +30,7 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
const quickActions = getQuickActions();
|
||||
const quickActions = useQuickActions(user);
|
||||
const [loadingAction, setLoadingAction] = useState<string | null>(null);
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { User } from "@supabase/supabase-js";
|
||||
|
||||
export const DEFAULT_QUICK_ACTIONS = [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
|
||||
export function getInputPlaceholder(width?: number) {
|
||||
if (!width) return "What's your role and what eats up most of your day?";
|
||||
|
||||
@@ -12,12 +18,15 @@ export function getInputPlaceholder(width?: number) {
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
export function getQuickActions() {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
export function getQuickActions(prompts?: string[] | null) {
|
||||
const normalizedPrompts =
|
||||
prompts
|
||||
?.map((prompt) => prompt.trim())
|
||||
.filter((prompt) => prompt.length > 0) ?? [];
|
||||
|
||||
return normalizedPrompts.length > 0
|
||||
? normalizedPrompts
|
||||
: DEFAULT_QUICK_ACTIONS;
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null) {
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
import { DEFAULT_QUICK_ACTIONS } from "./helpers";
|
||||
import { useQuickActions } from "./useQuickActions";
|
||||
|
||||
const { mockUseGetV1GetBusinessUnderstandingPrompts } = vi.hoisted(() => ({
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/auth/auth", () => ({
|
||||
useGetV1GetBusinessUnderstandingPrompts:
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts,
|
||||
}));
|
||||
|
||||
function makeUser() {
|
||||
return { id: "user-1" } as User;
|
||||
}
|
||||
|
||||
describe("useQuickActions", () => {
|
||||
it("uses server prompts when available", () => {
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
|
||||
data: ["Help me automate onboarding", "Find my biggest bottleneck"],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useQuickActions(makeUser()));
|
||||
|
||||
expect(result.current).toEqual([
|
||||
"Help me automate onboarding",
|
||||
"Find my biggest bottleneck",
|
||||
]);
|
||||
expect(mockUseGetV1GetBusinessUnderstandingPrompts).toHaveBeenCalledWith({
|
||||
query: expect.objectContaining({ enabled: true }),
|
||||
});
|
||||
});
|
||||
|
||||
it("falls back to defaults when the user is not authenticated", () => {
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
|
||||
data: undefined,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useQuickActions(null));
|
||||
|
||||
expect(result.current).toEqual(DEFAULT_QUICK_ACTIONS);
|
||||
expect(mockUseGetV1GetBusinessUnderstandingPrompts).toHaveBeenCalledWith({
|
||||
query: expect.objectContaining({ enabled: false }),
|
||||
});
|
||||
});
|
||||
|
||||
it("falls back to defaults when the API returns no prompts", () => {
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
|
||||
data: [],
|
||||
error: new Error("no prompts"),
|
||||
isError: true,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useQuickActions(makeUser()));
|
||||
|
||||
expect(result.current).toEqual(DEFAULT_QUICK_ACTIONS);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV1GetBusinessUnderstandingPrompts } from "@/app/api/__generated__/endpoints/auth/auth";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
import { getQuickActions } from "./helpers";
|
||||
|
||||
export function useQuickActions(user?: User | null) {
|
||||
const quickPrompts = useGetV1GetBusinessUnderstandingPrompts({
|
||||
query: {
|
||||
enabled: Boolean(user),
|
||||
select: (response) => okData(response)?.prompts,
|
||||
},
|
||||
}).data;
|
||||
|
||||
return getQuickActions(quickPrompts);
|
||||
}
|
||||
@@ -181,6 +181,14 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
|
||||
if (parts.length === 0) return;
|
||||
|
||||
// Merge consecutive assistant messages into a single UIMessage
|
||||
// to avoid split bubbles on page reload.
|
||||
const prevUI = uiMessages[uiMessages.length - 1];
|
||||
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
|
||||
prevUI.parts.push(...parts);
|
||||
return;
|
||||
}
|
||||
|
||||
uiMessages.push({
|
||||
id: `${sessionId}-${index}`,
|
||||
role: msg.role,
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
|
||||
export type CreateAgentToolOutput =
|
||||
| AgentPreviewResponse
|
||||
@@ -134,7 +134,7 @@ export function ToolIcon({
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={24} />;
|
||||
return <ScaleLoader size={14} />;
|
||||
}
|
||||
return <PlusIcon size={14} weight="regular" className="text-neutral-400" />;
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
|
||||
export type EditAgentToolOutput =
|
||||
| AgentPreviewResponse
|
||||
@@ -121,7 +121,7 @@ export function ToolIcon({
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={24} />;
|
||||
return <ScaleLoader size={14} />;
|
||||
}
|
||||
return (
|
||||
<PencilLineIcon size={14} weight="regular" className="text-neutral-400" />
|
||||
|
||||
@@ -0,0 +1,296 @@
|
||||
"use client";
|
||||
|
||||
import type { ToolUIPart } from "ai";
|
||||
import {
|
||||
FileIcon,
|
||||
FolderIcon,
|
||||
FolderPlusIcon,
|
||||
FoldersIcon,
|
||||
TrashIcon,
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import {
|
||||
File as TreeFile,
|
||||
Folder as TreeFolder,
|
||||
Tree,
|
||||
type TreeViewElement,
|
||||
} from "@/components/molecules/file-tree";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import {
|
||||
ContentCard,
|
||||
ContentCardHeader,
|
||||
ContentCardTitle,
|
||||
ContentGrid,
|
||||
ContentHint,
|
||||
ContentMessage,
|
||||
} from "../../components/ToolAccordion/AccordionContent";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import {
|
||||
getAnimationText,
|
||||
getFolderToolOutput,
|
||||
isAgentsMoved,
|
||||
isErrorOutput,
|
||||
isFolderCreated,
|
||||
isFolderDeleted,
|
||||
isFolderList,
|
||||
isFolderMoved,
|
||||
isFolderUpdated,
|
||||
type FolderInfo,
|
||||
type FolderToolOutput,
|
||||
type FolderTreeInfo,
|
||||
} from "./helpers";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Icons */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function ToolStatusIcon({
|
||||
isStreaming,
|
||||
isError,
|
||||
}: {
|
||||
isStreaming: boolean;
|
||||
isError: boolean;
|
||||
}) {
|
||||
if (isError) {
|
||||
return (
|
||||
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={14} />;
|
||||
}
|
||||
return <FolderIcon size={14} weight="regular" className="text-neutral-400" />;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Folder card */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function FolderCard({ folder }: { folder: FolderInfo }) {
|
||||
return (
|
||||
<ContentCard>
|
||||
<ContentCardHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
{folder.color ? (
|
||||
<span
|
||||
className="inline-block h-3 w-3 rounded-full"
|
||||
style={{ backgroundColor: folder.color }}
|
||||
/>
|
||||
) : (
|
||||
<FolderIcon size={14} weight="fill" className="text-neutral-600" />
|
||||
)}
|
||||
<ContentCardTitle>{folder.name}</ContentCardTitle>
|
||||
</div>
|
||||
</ContentCardHeader>
|
||||
<ContentHint>
|
||||
{folder.agent_count} agent{folder.agent_count !== 1 ? "s" : ""}
|
||||
{folder.subfolder_count > 0 &&
|
||||
` · ${folder.subfolder_count} subfolder${folder.subfolder_count !== 1 ? "s" : ""}`}
|
||||
</ContentHint>
|
||||
{folder.agents && folder.agents.length > 0 && (
|
||||
<div className="mt-2 space-y-1 border-t border-neutral-200 pt-2">
|
||||
{folder.agents.map((a) => (
|
||||
<div key={a.id} className="flex items-center gap-1.5">
|
||||
<FileIcon
|
||||
size={12}
|
||||
weight="duotone"
|
||||
className="text-neutral-600"
|
||||
/>
|
||||
<span className="text-xs text-zinc-600">{a.name}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</ContentCard>
|
||||
);
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tree renderer using file-tree component */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
type TreeNode = TreeViewElement & { isAgent?: boolean };
|
||||
|
||||
function folderTreeToElements(nodes: FolderTreeInfo[]): TreeNode[] {
|
||||
return nodes.map((node) => {
|
||||
const children: TreeNode[] = [
|
||||
...folderTreeToElements(node.children),
|
||||
...(node.agents ?? []).map((a) => ({
|
||||
id: a.id,
|
||||
name: a.name,
|
||||
isAgent: true,
|
||||
})),
|
||||
];
|
||||
return {
|
||||
id: node.id,
|
||||
name: `${node.name} (${node.agent_count} agent${node.agent_count !== 1 ? "s" : ""})`,
|
||||
children: children.length > 0 ? children : undefined,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function collectAllIDs(nodes: FolderTreeInfo[]): string[] {
|
||||
return nodes.flatMap((n) => [n.id, ...collectAllIDs(n.children)]);
|
||||
}
|
||||
|
||||
function FolderTreeView({ tree }: { tree: FolderTreeInfo[] }) {
|
||||
const elements = folderTreeToElements(tree);
|
||||
const allIDs = collectAllIDs(tree);
|
||||
|
||||
return (
|
||||
<Tree
|
||||
initialExpandedItems={allIDs}
|
||||
elements={elements}
|
||||
openIcon={
|
||||
<FolderIcon size={16} weight="fill" className="text-neutral-600" />
|
||||
}
|
||||
closeIcon={
|
||||
<FolderIcon size={16} weight="duotone" className="text-neutral-600" />
|
||||
}
|
||||
className="max-h-64"
|
||||
>
|
||||
{elements.map((el) => (
|
||||
<FolderTreeNodes key={el.id} element={el} />
|
||||
))}
|
||||
</Tree>
|
||||
);
|
||||
}
|
||||
|
||||
function FolderTreeNodes({ element }: { element: TreeNode }) {
|
||||
if (element.isAgent) {
|
||||
return (
|
||||
<TreeFile
|
||||
value={element.id}
|
||||
fileIcon={
|
||||
<FileIcon size={14} weight="duotone" className="text-neutral-600" />
|
||||
}
|
||||
>
|
||||
<span className="text-sm text-zinc-700">{element.name}</span>
|
||||
</TreeFile>
|
||||
);
|
||||
}
|
||||
|
||||
if (element.children && element.children.length > 0) {
|
||||
return (
|
||||
<TreeFolder value={element.id} element={element.name} isSelectable>
|
||||
{element.children.map((child) => (
|
||||
<FolderTreeNodes key={child.id} element={child as TreeNode} />
|
||||
))}
|
||||
</TreeFolder>
|
||||
);
|
||||
}
|
||||
|
||||
return <TreeFolder value={element.id} element={element.name} isSelectable />;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Accordion content per output type */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function AccordionContent({ output }: { output: FolderToolOutput }) {
|
||||
if (isFolderCreated(output)) {
|
||||
return (
|
||||
<ContentGrid>
|
||||
<FolderCard folder={output.folder} />
|
||||
</ContentGrid>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFolderList(output)) {
|
||||
if (output.tree && output.tree.length > 0) {
|
||||
return <FolderTreeView tree={output.tree} />;
|
||||
}
|
||||
if (output.folders && output.folders.length > 0) {
|
||||
return (
|
||||
<ContentGrid className="sm:grid-cols-2">
|
||||
{output.folders.map((folder) => (
|
||||
<FolderCard key={folder.id} folder={folder} />
|
||||
))}
|
||||
</ContentGrid>
|
||||
);
|
||||
}
|
||||
return <ContentMessage>No folders found.</ContentMessage>;
|
||||
}
|
||||
|
||||
if (isFolderUpdated(output) || isFolderMoved(output)) {
|
||||
return (
|
||||
<ContentGrid>
|
||||
<FolderCard folder={output.folder} />
|
||||
</ContentGrid>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFolderDeleted(output)) {
|
||||
return <ContentMessage>{output.message}</ContentMessage>;
|
||||
}
|
||||
|
||||
if (isAgentsMoved(output)) {
|
||||
return <ContentMessage>{output.message}</ContentMessage>;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Main component */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getAccordionTitle(output: FolderToolOutput): string {
|
||||
if (isFolderCreated(output)) return `Created "${output.folder.name}"`;
|
||||
if (isFolderList(output))
|
||||
return `${output.count} folder${output.count !== 1 ? "s" : ""}`;
|
||||
if (isFolderUpdated(output)) return `Updated "${output.folder.name}"`;
|
||||
if (isFolderMoved(output)) return `Moved "${output.folder.name}"`;
|
||||
if (isFolderDeleted(output)) return "Folder deleted";
|
||||
if (isAgentsMoved(output))
|
||||
return `Moved ${output.count} agent${output.count !== 1 ? "s" : ""}`;
|
||||
return "Folder operation";
|
||||
}
|
||||
|
||||
function getAccordionIcon(output: FolderToolOutput) {
|
||||
if (isFolderCreated(output))
|
||||
return <FolderPlusIcon size={32} weight="light" />;
|
||||
if (isFolderList(output)) return <FoldersIcon size={32} weight="light" />;
|
||||
if (isFolderDeleted(output)) return <TrashIcon size={32} weight="light" />;
|
||||
return <FolderIcon size={32} weight="light" />;
|
||||
}
|
||||
|
||||
export function FolderTool({ part }: Props) {
|
||||
const text = getAnimationText(part);
|
||||
const output = getFolderToolOutput(part);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError =
|
||||
part.state === "output-error" || (!!output && isErrorOutput(output));
|
||||
|
||||
const hasContent =
|
||||
part.state === "output-available" && !!output && !isErrorOutput(output);
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolStatusIcon isStreaming={isStreaming} isError={isError} />
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{hasContent && output && (
|
||||
<ToolAccordion
|
||||
icon={getAccordionIcon(output)}
|
||||
title={getAccordionTitle(output)}
|
||||
defaultExpanded={isFolderList(output)}
|
||||
>
|
||||
<AccordionContent output={output} />
|
||||
</ToolAccordion>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
import type { ToolUIPart } from "ai";
|
||||
|
||||
interface FolderAgentSummary {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
interface FolderInfo {
|
||||
id: string;
|
||||
name: string;
|
||||
parent_id?: string | null;
|
||||
icon?: string | null;
|
||||
color?: string | null;
|
||||
agent_count: number;
|
||||
subfolder_count: number;
|
||||
agents?: FolderAgentSummary[] | null;
|
||||
}
|
||||
|
||||
interface FolderTreeInfo extends FolderInfo {
|
||||
children: FolderTreeInfo[];
|
||||
}
|
||||
|
||||
interface FolderCreatedOutput {
|
||||
type: "folder_created";
|
||||
message: string;
|
||||
folder: FolderInfo;
|
||||
}
|
||||
|
||||
interface FolderListOutput {
|
||||
type: "folder_list";
|
||||
message: string;
|
||||
folders?: FolderInfo[];
|
||||
tree?: FolderTreeInfo[];
|
||||
count: number;
|
||||
}
|
||||
|
||||
interface FolderUpdatedOutput {
|
||||
type: "folder_updated";
|
||||
message: string;
|
||||
folder: FolderInfo;
|
||||
}
|
||||
|
||||
interface FolderMovedOutput {
|
||||
type: "folder_moved";
|
||||
message: string;
|
||||
folder: FolderInfo;
|
||||
target_parent_id?: string | null;
|
||||
}
|
||||
|
||||
interface FolderDeletedOutput {
|
||||
type: "folder_deleted";
|
||||
message: string;
|
||||
folder_id: string;
|
||||
}
|
||||
|
||||
interface AgentsMovedOutput {
|
||||
type: "agents_moved_to_folder";
|
||||
message: string;
|
||||
agent_ids: string[];
|
||||
folder_id?: string | null;
|
||||
count: number;
|
||||
}
|
||||
|
||||
interface ErrorOutput {
|
||||
type: "error";
|
||||
message: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export type FolderToolOutput =
|
||||
| FolderCreatedOutput
|
||||
| FolderListOutput
|
||||
| FolderUpdatedOutput
|
||||
| FolderMovedOutput
|
||||
| FolderDeletedOutput
|
||||
| AgentsMovedOutput
|
||||
| ErrorOutput;
|
||||
|
||||
export type { FolderAgentSummary, FolderInfo, FolderTreeInfo };
|
||||
|
||||
function parseOutput(output: unknown): FolderToolOutput | null {
|
||||
if (!output) return null;
|
||||
if (typeof output === "string") {
|
||||
const trimmed = output.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
return parseOutput(JSON.parse(trimmed) as unknown);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
if (typeof output === "object") {
|
||||
const obj = output as Record<string, unknown>;
|
||||
if (typeof obj.type === "string") {
|
||||
return output as FolderToolOutput;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function getFolderToolOutput(part: {
|
||||
output?: unknown;
|
||||
}): FolderToolOutput | null {
|
||||
return parseOutput(part.output);
|
||||
}
|
||||
|
||||
export function isFolderCreated(o: FolderToolOutput): o is FolderCreatedOutput {
|
||||
return o.type === "folder_created";
|
||||
}
|
||||
|
||||
export function isFolderList(o: FolderToolOutput): o is FolderListOutput {
|
||||
return o.type === "folder_list";
|
||||
}
|
||||
|
||||
export function isFolderUpdated(o: FolderToolOutput): o is FolderUpdatedOutput {
|
||||
return o.type === "folder_updated";
|
||||
}
|
||||
|
||||
export function isFolderMoved(o: FolderToolOutput): o is FolderMovedOutput {
|
||||
return o.type === "folder_moved";
|
||||
}
|
||||
|
||||
export function isFolderDeleted(o: FolderToolOutput): o is FolderDeletedOutput {
|
||||
return o.type === "folder_deleted";
|
||||
}
|
||||
|
||||
export function isAgentsMoved(o: FolderToolOutput): o is AgentsMovedOutput {
|
||||
return o.type === "agents_moved_to_folder";
|
||||
}
|
||||
|
||||
export function isErrorOutput(o: FolderToolOutput): o is ErrorOutput {
|
||||
return o.type === "error";
|
||||
}
|
||||
|
||||
export function getAnimationText(part: {
|
||||
type: string;
|
||||
state: ToolUIPart["state"];
|
||||
output?: unknown;
|
||||
}): string {
|
||||
const toolName = part.type.replace(/^tool-/, "");
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available": {
|
||||
switch (toolName) {
|
||||
case "create_folder":
|
||||
return "Creating folder…";
|
||||
case "list_folders":
|
||||
return "Loading folders…";
|
||||
case "update_folder":
|
||||
return "Updating folder…";
|
||||
case "move_folder":
|
||||
return "Moving folder…";
|
||||
case "delete_folder":
|
||||
return "Deleting folder…";
|
||||
case "move_agents_to_folder":
|
||||
return "Moving agents…";
|
||||
default:
|
||||
return "Managing folders…";
|
||||
}
|
||||
}
|
||||
case "output-available": {
|
||||
const output = getFolderToolOutput(part);
|
||||
if (!output) return "Done";
|
||||
if (isErrorOutput(output)) return "Folder operation failed";
|
||||
return output.message;
|
||||
}
|
||||
case "output-error":
|
||||
return "Folder operation failed";
|
||||
default:
|
||||
return "Managing folders…";
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,13 @@ import {
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
type ToolCategory,
|
||||
extractToolName,
|
||||
getAnimationText,
|
||||
getToolCategory,
|
||||
truncate,
|
||||
} from "./helpers";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
@@ -48,77 +55,6 @@ function RenderMedia({
|
||||
return <OutputItem value={value} metadata={metadata} renderer={renderer} />;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool name helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function extractToolName(part: ToolUIPart): string {
|
||||
return part.type.replace(/^tool-/, "");
|
||||
}
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool categorization */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
type ToolCategory =
|
||||
| "bash"
|
||||
| "web"
|
||||
| "browser"
|
||||
| "file-read"
|
||||
| "file-write"
|
||||
| "file-delete"
|
||||
| "file-list"
|
||||
| "search"
|
||||
| "edit"
|
||||
| "todo"
|
||||
| "compaction"
|
||||
| "other";
|
||||
|
||||
function getToolCategory(toolName: string): ToolCategory {
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return "bash";
|
||||
case "web_fetch":
|
||||
case "WebSearch":
|
||||
case "WebFetch":
|
||||
return "web";
|
||||
case "browser_navigate":
|
||||
case "browser_act":
|
||||
case "browser_screenshot":
|
||||
return "browser";
|
||||
case "read_workspace_file":
|
||||
case "read_file":
|
||||
case "Read":
|
||||
return "file-read";
|
||||
case "write_workspace_file":
|
||||
case "write_file":
|
||||
case "Write":
|
||||
return "file-write";
|
||||
case "delete_workspace_file":
|
||||
return "file-delete";
|
||||
case "list_workspace_files":
|
||||
case "glob":
|
||||
case "Glob":
|
||||
return "file-list";
|
||||
case "grep":
|
||||
case "Grep":
|
||||
return "search";
|
||||
case "edit_file":
|
||||
case "Edit":
|
||||
return "edit";
|
||||
case "TodoWrite":
|
||||
return "todo";
|
||||
case "context_compaction":
|
||||
return "compaction";
|
||||
default:
|
||||
return "other";
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool icon */
|
||||
/* ------------------------------------------------------------------ */
|
||||
@@ -141,7 +77,7 @@ function ToolIcon({
|
||||
return <OrbitLoader size={14} />;
|
||||
}
|
||||
|
||||
const iconClass = "text-neutral-400";
|
||||
const iconClass = "text-green-500";
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
|
||||
@@ -210,191 +146,6 @@ function AccordionIcon({ category }: { category: ToolCategory }) {
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Input extraction */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
if (!input || typeof input !== "object") return null;
|
||||
const inp = input as Record<string, unknown>;
|
||||
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return typeof inp.command === "string" ? inp.command : null;
|
||||
case "web_fetch":
|
||||
case "WebFetch":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "WebSearch":
|
||||
return typeof inp.query === "string" ? inp.query : null;
|
||||
case "browser_navigate":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "browser_act":
|
||||
return typeof inp.action === "string"
|
||||
? inp.target
|
||||
? `${inp.action} ${inp.target}`
|
||||
: (inp.action as string)
|
||||
: null;
|
||||
case "browser_screenshot":
|
||||
return null;
|
||||
case "read_workspace_file":
|
||||
case "read_file":
|
||||
case "Read":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "write_workspace_file":
|
||||
case "write_file":
|
||||
case "Write":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "delete_workspace_file":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "glob":
|
||||
case "Glob":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "grep":
|
||||
case "Grep":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "edit_file":
|
||||
case "Edit":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "TodoWrite": {
|
||||
// Extract the in-progress task name for the status line
|
||||
const todos = Array.isArray(inp.todos) ? inp.todos : [];
|
||||
const active = todos.find(
|
||||
(t: Record<string, unknown>) => t.status === "in_progress",
|
||||
);
|
||||
if (active && typeof active.activeForm === "string")
|
||||
return active.activeForm;
|
||||
if (active && typeof active.content === "string") return active.content;
|
||||
return null;
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function truncate(text: string, maxLen: number): string {
|
||||
if (text.length <= maxLen) return text;
|
||||
return text.slice(0, maxLen).trimEnd() + "…";
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Animation text */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
|
||||
const toolName = extractToolName(part);
|
||||
const summary = getInputSummary(toolName, part.input);
|
||||
const shortSummary = summary ? truncate(summary, 60) : null;
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return shortSummary ? `Running: ${shortSummary}` : "Running command…";
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searching "${shortSummary}"`
|
||||
: "Searching the web…";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetching ${shortSummary}`
|
||||
: "Fetching web content…";
|
||||
case "browser":
|
||||
if (toolName === "browser_screenshot") return "Taking screenshot…";
|
||||
return shortSummary
|
||||
? `Browsing ${shortSummary}`
|
||||
: "Interacting with browser…";
|
||||
case "file-read":
|
||||
return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
|
||||
case "file-write":
|
||||
return shortSummary ? `Writing ${shortSummary}` : "Writing file…";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleting ${shortSummary}` : "Deleting file…";
|
||||
case "file-list":
|
||||
return shortSummary ? `Listing ${shortSummary}` : "Listing files…";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searching for "${shortSummary}"`
|
||||
: "Searching…";
|
||||
case "edit":
|
||||
return shortSummary ? `Editing ${shortSummary}` : "Editing file…";
|
||||
case "todo":
|
||||
return shortSummary ? `${shortSummary}` : "Updating task list…";
|
||||
case "compaction":
|
||||
return "Summarizing earlier messages…";
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}…`;
|
||||
}
|
||||
}
|
||||
case "output-available": {
|
||||
switch (category) {
|
||||
case "bash": {
|
||||
const exitCode = getExitCode(part.output);
|
||||
if (exitCode !== null && exitCode !== 0) {
|
||||
return `Command exited with code ${exitCode}`;
|
||||
}
|
||||
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
|
||||
}
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searched "${shortSummary}"`
|
||||
: "Web search completed";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetched ${shortSummary}`
|
||||
: "Fetched web content";
|
||||
case "browser":
|
||||
if (toolName === "browser_screenshot") return "Screenshot captured";
|
||||
return shortSummary
|
||||
? `Browsed ${shortSummary}`
|
||||
: "Browser action completed";
|
||||
case "file-read":
|
||||
return shortSummary ? `Read ${shortSummary}` : "File read completed";
|
||||
case "file-write":
|
||||
return shortSummary ? `Wrote ${shortSummary}` : "File written";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
|
||||
case "file-list":
|
||||
return "Listed files";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searched for "${shortSummary}"`
|
||||
: "Search completed";
|
||||
case "edit":
|
||||
return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
|
||||
case "todo":
|
||||
return "Updated task list";
|
||||
case "compaction":
|
||||
return "Earlier messages were summarized";
|
||||
default:
|
||||
return `${formatToolName(toolName)} completed`;
|
||||
}
|
||||
}
|
||||
case "output-error": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return "Command failed";
|
||||
case "web":
|
||||
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
|
||||
case "browser":
|
||||
return "Browser action failed";
|
||||
default:
|
||||
return `${formatToolName(toolName)} failed`;
|
||||
}
|
||||
}
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}…`;
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Output parsing helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
@@ -435,13 +186,6 @@ function extractMcpText(output: Record<string, unknown>): string | null {
|
||||
return null;
|
||||
}
|
||||
|
||||
function getExitCode(output: unknown): number | null {
|
||||
const parsed = parseOutput(output);
|
||||
if (!parsed) return null;
|
||||
if (typeof parsed.exit_code === "number") return parsed.exit_code;
|
||||
return null;
|
||||
}
|
||||
|
||||
function getStringField(
|
||||
obj: Record<string, unknown>,
|
||||
...keys: string[]
|
||||
|
||||
@@ -0,0 +1,285 @@
|
||||
import type { ToolUIPart } from "ai";
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool name helpers */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export function extractToolName(part: ToolUIPart): string {
|
||||
return part.type.replace(/^tool-/, "");
|
||||
}
|
||||
|
||||
export function formatToolName(name: string): string {
|
||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Tool categorization */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export type ToolCategory =
|
||||
| "bash"
|
||||
| "web"
|
||||
| "browser"
|
||||
| "file-read"
|
||||
| "file-write"
|
||||
| "file-delete"
|
||||
| "file-list"
|
||||
| "search"
|
||||
| "edit"
|
||||
| "todo"
|
||||
| "compaction"
|
||||
| "other";
|
||||
|
||||
export function getToolCategory(toolName: string): ToolCategory {
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return "bash";
|
||||
case "web_fetch":
|
||||
case "WebSearch":
|
||||
case "WebFetch":
|
||||
return "web";
|
||||
case "browser_navigate":
|
||||
case "browser_act":
|
||||
case "browser_screenshot":
|
||||
return "browser";
|
||||
case "read_workspace_file":
|
||||
case "read_file":
|
||||
case "Read":
|
||||
return "file-read";
|
||||
case "write_workspace_file":
|
||||
case "write_file":
|
||||
case "Write":
|
||||
return "file-write";
|
||||
case "delete_workspace_file":
|
||||
return "file-delete";
|
||||
case "list_workspace_files":
|
||||
case "glob":
|
||||
case "Glob":
|
||||
return "file-list";
|
||||
case "grep":
|
||||
case "Grep":
|
||||
return "search";
|
||||
case "edit_file":
|
||||
case "Edit":
|
||||
return "edit";
|
||||
case "TodoWrite":
|
||||
return "todo";
|
||||
case "context_compaction":
|
||||
return "compaction";
|
||||
default:
|
||||
return "other";
|
||||
}
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Input summary */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
if (!input || typeof input !== "object") return null;
|
||||
const inp = input as Record<string, unknown>;
|
||||
|
||||
switch (toolName) {
|
||||
case "bash_exec":
|
||||
return typeof inp.command === "string" ? inp.command : null;
|
||||
case "web_fetch":
|
||||
case "WebFetch":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "WebSearch":
|
||||
return typeof inp.query === "string" ? inp.query : null;
|
||||
case "browser_navigate":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "browser_act":
|
||||
if (typeof inp.action !== "string") return null;
|
||||
return typeof inp.target === "string"
|
||||
? `${inp.action} ${inp.target}`
|
||||
: inp.action;
|
||||
case "browser_screenshot":
|
||||
return null;
|
||||
case "read_workspace_file":
|
||||
case "read_file":
|
||||
case "Read":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "write_workspace_file":
|
||||
case "write_file":
|
||||
case "Write":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "delete_workspace_file":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "glob":
|
||||
case "Glob":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "grep":
|
||||
case "Grep":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "edit_file":
|
||||
case "Edit":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "TodoWrite": {
|
||||
const todos = Array.isArray(inp.todos) ? inp.todos : [];
|
||||
const active = todos.find(
|
||||
(t: unknown) =>
|
||||
t !== null &&
|
||||
typeof t === "object" &&
|
||||
(t as Record<string, unknown>).status === "in_progress",
|
||||
) as Record<string, unknown> | undefined;
|
||||
if (active && typeof active.activeForm === "string")
|
||||
return active.activeForm;
|
||||
if (active && typeof active.content === "string") return active.content;
|
||||
return null;
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function truncate(text: string, maxLen: number): string {
|
||||
if (text.length <= maxLen) return text;
|
||||
return text.slice(0, maxLen).trimEnd() + "\u2026";
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Exit code helper */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
function getExitCode(output: unknown): number | null {
|
||||
if (!output || typeof output !== "object") return null;
|
||||
const parsed = output as Record<string, unknown>;
|
||||
if (typeof parsed.exit_code === "number") return parsed.exit_code;
|
||||
return null;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------ */
|
||||
/* Animation text */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
export function getAnimationText(
|
||||
part: ToolUIPart,
|
||||
category: ToolCategory,
|
||||
): string {
|
||||
const toolName = extractToolName(part);
|
||||
const summary = getInputSummary(toolName, part.input);
|
||||
const shortSummary = summary ? truncate(summary, 60) : null;
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return shortSummary
|
||||
? `Running: ${shortSummary}`
|
||||
: "Running command\u2026";
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searching "${shortSummary}"`
|
||||
: "Searching the web\u2026";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetching ${shortSummary}`
|
||||
: "Fetching web content\u2026";
|
||||
case "browser":
|
||||
if (toolName === "browser_screenshot")
|
||||
return "Taking screenshot\u2026";
|
||||
return shortSummary
|
||||
? `Browsing ${shortSummary}`
|
||||
: "Interacting with browser\u2026";
|
||||
case "file-read":
|
||||
return shortSummary
|
||||
? `Reading ${shortSummary}`
|
||||
: "Reading file\u2026";
|
||||
case "file-write":
|
||||
return shortSummary
|
||||
? `Writing ${shortSummary}`
|
||||
: "Writing file\u2026";
|
||||
case "file-delete":
|
||||
return shortSummary
|
||||
? `Deleting ${shortSummary}`
|
||||
: "Deleting file\u2026";
|
||||
case "file-list":
|
||||
return shortSummary
|
||||
? `Listing ${shortSummary}`
|
||||
: "Listing files\u2026";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searching for "${shortSummary}"`
|
||||
: "Searching\u2026";
|
||||
case "edit":
|
||||
return shortSummary
|
||||
? `Editing ${shortSummary}`
|
||||
: "Editing file\u2026";
|
||||
case "todo":
|
||||
return shortSummary ? `${shortSummary}` : "Updating task list\u2026";
|
||||
case "compaction":
|
||||
return "Summarizing earlier messages\u2026";
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}\u2026`;
|
||||
}
|
||||
}
|
||||
case "output-available": {
|
||||
switch (category) {
|
||||
case "bash": {
|
||||
const exitCode = getExitCode(part.output);
|
||||
if (exitCode !== null && exitCode !== 0) {
|
||||
return `Command exited with code ${exitCode}`;
|
||||
}
|
||||
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
|
||||
}
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
return shortSummary
|
||||
? `Searched "${shortSummary}"`
|
||||
: "Web search completed";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetched ${shortSummary}`
|
||||
: "Fetched web content";
|
||||
case "browser":
|
||||
if (toolName === "browser_screenshot") return "Screenshot captured";
|
||||
return shortSummary
|
||||
? `Browsed ${shortSummary}`
|
||||
: "Browser action completed";
|
||||
case "file-read":
|
||||
return shortSummary ? `Read ${shortSummary}` : "File read completed";
|
||||
case "file-write":
|
||||
return shortSummary ? `Wrote ${shortSummary}` : "File written";
|
||||
case "file-delete":
|
||||
return shortSummary ? `Deleted ${shortSummary}` : "File deleted";
|
||||
case "file-list":
|
||||
return "Listed files";
|
||||
case "search":
|
||||
return shortSummary
|
||||
? `Searched for "${shortSummary}"`
|
||||
: "Search completed";
|
||||
case "edit":
|
||||
return shortSummary ? `Edited ${shortSummary}` : "Edit completed";
|
||||
case "todo":
|
||||
return "Updated task list";
|
||||
case "compaction":
|
||||
return "Earlier messages were summarized";
|
||||
default:
|
||||
return `${formatToolName(toolName)} completed`;
|
||||
}
|
||||
}
|
||||
case "output-error": {
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return "Command failed";
|
||||
case "web":
|
||||
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
|
||||
case "browser":
|
||||
return "Browser action failed";
|
||||
default:
|
||||
return `${formatToolName(toolName)} failed`;
|
||||
}
|
||||
}
|
||||
default:
|
||||
return `Running ${formatToolName(toolName)}\u2026`;
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
|
||||
import {
|
||||
ContentGrid,
|
||||
@@ -86,7 +86,7 @@ export function RunAgentTool({ part }: Props) {
|
||||
|
||||
{isStreaming && !output && (
|
||||
<ToolAccordion
|
||||
icon={<OrbitLoader size={32} />}
|
||||
icon={<ScaleLoader size={14} />}
|
||||
title="Running agent, this may take a few minutes. Play while you wait."
|
||||
expanded={true}
|
||||
>
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
|
||||
export interface RunAgentInput {
|
||||
username_agent_slug?: string;
|
||||
@@ -171,7 +171,7 @@ export function ToolIcon({
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={24} />;
|
||||
return <ScaleLoader size={14} />;
|
||||
}
|
||||
return <PlayIcon size={14} weight="regular" className="text-neutral-400" />;
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
|
||||
/** Block details returned on first run_block attempt (before input_data provided). */
|
||||
export interface BlockDetailsResponse {
|
||||
@@ -157,7 +157,7 @@ export function ToolIcon({
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={24} />;
|
||||
return <ScaleLoader size={14} />;
|
||||
}
|
||||
return <PlayIcon size={14} weight="regular" className="text-neutral-400" />;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import { ResponseType } from "@/app/api/__generated__/models/responseType";
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import { WarningDiamondIcon, PlugsConnectedIcon } from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
|
||||
// ------------------------------------------------------------------ //
|
||||
// Re-export generated types for use by RunMCPTool components
|
||||
@@ -212,7 +212,7 @@ export function ToolIcon({
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={24} />;
|
||||
return <ScaleLoader size={14} />;
|
||||
}
|
||||
return (
|
||||
<PlugsConnectedIcon
|
||||
|
||||
@@ -5,10 +5,9 @@ import {
|
||||
type getV2ListSessionsResponse,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { uploadFileDirect } from "@/lib/direct-upload";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
@@ -129,49 +128,25 @@ export function useCopilotPage() {
|
||||
files: File[],
|
||||
sid: string,
|
||||
): Promise<UploadedFile[]> {
|
||||
// Upload directly to the Python backend, bypassing the Next.js serverless
|
||||
// proxy. Vercel's 4.5 MB function payload limit would reject larger files
|
||||
// when routed through /api/workspace/files/upload.
|
||||
const { token, error: tokenError } = await getWebSocketToken();
|
||||
if (tokenError || !token) {
|
||||
toast({
|
||||
title: "Authentication error",
|
||||
description: "Please sign in again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return [];
|
||||
}
|
||||
|
||||
const backendBase = environment.getAGPTServerBaseUrl();
|
||||
|
||||
const results = await Promise.allSettled(
|
||||
files.map(async (file) => {
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
const url = new URL("/api/workspace/files/upload", backendBase);
|
||||
url.searchParams.set("session_id", sid);
|
||||
const res = await fetch(url.toString(), {
|
||||
method: "POST",
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
body: formData,
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.text();
|
||||
try {
|
||||
const data = await uploadFileDirect(file, sid);
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} as UploadedFile;
|
||||
} catch (err) {
|
||||
console.error("File upload failed:", err);
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: file.name,
|
||||
variant: "destructive",
|
||||
});
|
||||
throw new Error(err);
|
||||
throw err;
|
||||
}
|
||||
const data = await res.json();
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} as UploadedFile;
|
||||
}),
|
||||
);
|
||||
return results
|
||||
|
||||
@@ -92,12 +92,18 @@ export function useCopilotStream({
|
||||
// Set when the user explicitly clicks stop — prevents onError from
|
||||
// triggering a reconnect cycle for the resulting AbortError.
|
||||
const isUserStoppingRef = useRef(false);
|
||||
// Set when all reconnect attempts are exhausted — prevents hasActiveStream
|
||||
// from keeping the UI blocked forever when the backend is slow to clear it.
|
||||
// Must be state (not ref) so that setting it triggers a re-render and
|
||||
// recomputes `isReconnecting`.
|
||||
const [reconnectExhausted, setReconnectExhausted] = useState(false);
|
||||
|
||||
function handleReconnect(sid: string) {
|
||||
if (isReconnectScheduledRef.current || !sid) return;
|
||||
|
||||
const nextAttempt = reconnectAttemptsRef.current + 1;
|
||||
if (nextAttempt > RECONNECT_MAX_ATTEMPTS) {
|
||||
setReconnectExhausted(true);
|
||||
toast({
|
||||
title: "Connection lost",
|
||||
description: "Unable to reconnect. Please refresh the page.",
|
||||
@@ -146,7 +152,11 @@ export function useCopilotStream({
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if backend executor is still running after clean close
|
||||
// Check if backend executor is still running after clean close.
|
||||
// Brief delay to let the backend clear active_stream — without this,
|
||||
// the refetch often races and sees stale active_stream=true, triggering
|
||||
// unnecessary reconnect cycles.
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
const result = await refetchSession();
|
||||
const d = result.data;
|
||||
const backendActive =
|
||||
@@ -276,6 +286,7 @@ export function useCopilotStream({
|
||||
setIsReconnectScheduled(false);
|
||||
hasShownDisconnectToast.current = false;
|
||||
isUserStoppingRef.current = false;
|
||||
setReconnectExhausted(false);
|
||||
hasResumedRef.current.clear();
|
||||
return () => {
|
||||
clearTimeout(reconnectTimerRef.current);
|
||||
@@ -299,6 +310,7 @@ export function useCopilotStream({
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
setReconnectExhausted(false);
|
||||
}
|
||||
}
|
||||
}, [status, sessionId, queryClient, isReconnectScheduled]);
|
||||
@@ -358,10 +370,12 @@ export function useCopilotStream({
|
||||
}, [hasActiveStream]);
|
||||
|
||||
// True while reconnecting or backend has active stream but we haven't connected yet.
|
||||
// Suppressed when the user explicitly stopped — the backend may take a moment
|
||||
// to clear active_stream but the UI should be responsive immediately.
|
||||
// Suppressed when the user explicitly stopped or when all reconnect attempts
|
||||
// are exhausted — the backend may be slow to clear active_stream but the UI
|
||||
// should remain responsive.
|
||||
const isReconnecting =
|
||||
!isUserStoppingRef.current &&
|
||||
!reconnectExhausted &&
|
||||
(isReconnectScheduled ||
|
||||
(hasActiveStream && status !== "streaming" && status !== "submitted"));
|
||||
|
||||
|
||||
@@ -42,8 +42,8 @@ export function AgentVersionChangelog({
|
||||
|
||||
// Create version info from available graph versions
|
||||
const storeData = okData(storeAgentData) as StoreAgentDetails | undefined;
|
||||
const agentVersions: VersionInfo[] = storeData?.agentGraphVersions
|
||||
? storeData.agentGraphVersions
|
||||
const agentVersions: VersionInfo[] = storeData?.graph_versions
|
||||
? storeData.graph_versions
|
||||
.map((versionStr: string) => parseInt(versionStr, 10))
|
||||
.sort((a: number, b: number) => b - a) // Sort descending (newest first)
|
||||
.map((version: number) => ({
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user