mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
17 Commits
ntindle/wa
...
swiftyos/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af7e8d19fe | ||
|
|
eadc68f2a5 | ||
|
|
eca7b5e793 | ||
|
|
c304a4937a | ||
|
|
8cfabcf4fd | ||
|
|
7bf407b66c | ||
|
|
7ead4c040f | ||
|
|
0f813f1bf9 | ||
|
|
aa08063939 | ||
|
|
bde6a4c0df | ||
|
|
d56452898a | ||
|
|
7507240177 | ||
|
|
d7c3f5b8fc | ||
|
|
3e108a813a | ||
|
|
08c49a78f8 | ||
|
|
0b9e0665dd | ||
|
|
f6f268a1f0 |
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.blocks
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.external.middleware import require_auth, require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
@@ -278,7 +279,7 @@ async def get_store_agents(
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
@@ -306,13 +307,13 @@ async def get_store_agent(
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
@@ -348,7 +349,7 @@ async def get_store_creators(
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
@@ -14,3 +18,42 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||
"""
|
||||
Get store listings with their version history for admins.
|
||||
|
||||
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
Paginated listings with their versions
|
||||
"""
|
||||
try:
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception as e:
|
||||
logger.exception("Error getting admin listings with versions: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving listings with versions"
|
||||
},
|
||||
)
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
) -> store_model.StoreSubmissionAdminView:
|
||||
"""
|
||||
Review a store listing submission.
|
||||
|
||||
@@ -84,31 +73,24 @@ async def review_submission(
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmission with updated review information
|
||||
StoreSubmissionAdminView with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches whenever approval state changes, since store visibility can change
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUserRowResponse,
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
def _to_response(invited_user: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return InvitedUserResponse(**invited_user.model_dump())
|
||||
|
||||
|
||||
def _to_bulk_response(result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return BulkInvitedUsersResponse(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
_to_response(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users = await list_invited_users()
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[_to_response(invited_user) for invited_user in invited_users]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s created invited user for %s",
|
||||
admin_user_id,
|
||||
request.email,
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
return _to_response(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return _to_bulk_response(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
return _to_response(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
return _to_response(invited_user)
|
||||
@@ -0,0 +1,165 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=[_sample_invited_user()]),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -11,7 +11,7 @@ from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -25,6 +25,7 @@ from backend.copilot.model import (
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -141,6 +142,20 @@ class CancelSessionResponse(BaseModel):
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UpdateSessionTitleRequest(BaseModel):
|
||||
"""Request model for updating a session's title."""
|
||||
|
||||
title: str
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_must_not_be_blank(cls, v: str) -> str:
|
||||
stripped = v.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Title must not be blank")
|
||||
return stripped
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -264,6 +279,43 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def update_session_title_route(
|
||||
session_id: str,
|
||||
request: UpdateSessionTitleRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Update the title of a chat session.
|
||||
|
||||
Allows the user to rename their chat session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
request: Request body containing the new title.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
dict: Status of the update.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
success = await update_session_title(session_id, user_id, request.title)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -753,7 +805,6 @@ async def resume_session_stream(
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Tests for chat route file_ids validation and enrichment."""
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -17,6 +19,7 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
@@ -24,7 +27,95 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---- file_ids Pydantic validation (B1) ----
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
@@ -92,7 +183,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- UUID format filtering ----
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
@@ -131,7 +222,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ---- Cross-workspace file_ids ----
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
|
||||
@@ -8,7 +8,6 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
The requested LibraryAgent.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during retrieval.
|
||||
"""
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
@@ -398,6 +397,7 @@ async def create_library_agent(
|
||||
hitl_safe_mode: bool = True,
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
create_library_agents_for_sub_graphs: bool = True,
|
||||
folder_id: str | None = None,
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
@@ -414,12 +414,18 @@ async def create_library_agent(
|
||||
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during creation or if image generation fails.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
|
||||
)
|
||||
|
||||
# Authorization: FK only checks existence, not ownership.
|
||||
# Verify the folder belongs to this user to prevent cross-user nesting.
|
||||
if folder_id:
|
||||
await get_folder(folder_id, user_id)
|
||||
|
||||
graph_entries = (
|
||||
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
|
||||
)
|
||||
@@ -432,7 +438,6 @@ async def create_library_agent(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
@@ -448,6 +453,11 @@ async def create_library_agent(
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
|
||||
async def create_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new graph and add it to the user's library."""
|
||||
graph.version = 1
|
||||
@@ -542,6 +553,7 @@ async def create_graph_in_library(
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the store listing or associated agent is not found.
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
|
||||
include_subgraphs=False,
|
||||
)
|
||||
if not graph_model:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
nodes: list[library_model.LibraryFolderTree],
|
||||
visited: set[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Collect all folder IDs from a folder tree."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
ids: list[str] = []
|
||||
for n in nodes:
|
||||
if n.id in visited:
|
||||
continue
|
||||
visited.add(n.id)
|
||||
ids.append(n.id)
|
||||
ids.extend(collect_tree_ids(n.children, visited))
|
||||
return ids
|
||||
|
||||
|
||||
async def get_folder_agent_summaries(
|
||||
user_id: str, folder_id: str
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of agents in a folder (id, name, description)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, folder_id=folder_id, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_root_agent_summaries(
|
||||
user_id: str,
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, include_root_only=True, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_folder_agents_map(
|
||||
user_id: str, folder_ids: list[str]
|
||||
) -> dict[str, list[dict[str, str | None]]]:
|
||||
"""Get agent summaries for multiple folders concurrently."""
|
||||
results = await asyncio.gather(
|
||||
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
|
||||
)
|
||||
return dict(zip(folder_ids, results))
|
||||
|
||||
|
||||
##############################################
|
||||
########### Presets DB Functions #############
|
||||
##############################################
|
||||
|
||||
@@ -4,7 +4,6 @@ import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(db.NotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
@@ -23,7 +21,7 @@ def clear_all_caches():
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
return await store_db.get_store_creator(username=username.lower())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_agents = [
|
||||
prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video=None,
|
||||
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data
|
||||
# Mock data - StoreAgent view already contains the active version data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - should use active version data
|
||||
# Verify results - constructed from the StoreAgent view
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "version123"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
assert result.store_listing_version_id == "version123"
|
||||
assert result.graph_id == "test-graph-id"
|
||||
assert result.runs == 10
|
||||
assert result.rating == 4.5
|
||||
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
# Verify single StoreAgent lookup
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator_details(mocker):
|
||||
async def test_get_store_creator(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
name="Test Creator",
|
||||
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
|
||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_creator_details("creator")
|
||||
result = await db.get_store_creator("creator")
|
||||
|
||||
# Verify results
|
||||
assert result.username == "creator"
|
||||
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
# Mock data
|
||||
now = datetime.now()
|
||||
|
||||
# Mock agent graph (with no pending submissions) and user with profile
|
||||
mock_profile = prisma.models.Profile(
|
||||
id="profile-id",
|
||||
userId="user-id",
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test",
|
||||
isFeatured=False,
|
||||
links=[],
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
mock_user = prisma.models.User(
|
||||
id="user-id",
|
||||
email="test@example.com",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
Profile=[mock_profile],
|
||||
emailVerified=True,
|
||||
metadata="{}", # type: ignore[reportArgumentType]
|
||||
integrations="",
|
||||
maxEmailsPerDay=1,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
version=1,
|
||||
userId="user-id",
|
||||
createdAt=datetime.now(),
|
||||
createdAt=now,
|
||||
isActive=True,
|
||||
StoreListingVersions=[],
|
||||
User=mock_user,
|
||||
)
|
||||
|
||||
mock_listing = prisma.models.StoreListing(
|
||||
# Mock the created StoreListingVersion (returned by create)
|
||||
mock_store_listing_obj = prisma.models.StoreListing(
|
||||
id="listing-id",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
owningUserId="user-id",
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
subHeading="Test heading",
|
||||
imageUrls=["image.jpg"],
|
||||
categories=["test"],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
)
|
||||
],
|
||||
useForOnboarding=False,
|
||||
)
|
||||
mock_version = prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
subHeading="",
|
||||
imageUrls=[],
|
||||
categories=[],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
submittedAt=now,
|
||||
StoreListing=mock_store_listing_obj,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
|
||||
# Mock transaction context manager
|
||||
mock_tx = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.store.db.transaction",
|
||||
return_value=mocker.AsyncMock(
|
||||
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||
__aexit__=mocker.AsyncMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
|
||||
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||
|
||||
# Call function
|
||||
result = await db.create_store_submission(
|
||||
user_id="user-id",
|
||||
agent_id="agent-id",
|
||||
agent_version=1,
|
||||
graph_id="agent-id",
|
||||
graph_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
|
||||
# Verify results
|
||||
assert result.name == "Test Agent"
|
||||
assert result.description == "Test description"
|
||||
assert result.store_listing_version_id == "version-id"
|
||||
assert result.listing_version_id == "version-id"
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
mock_slv.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
|
||||
description="Test description",
|
||||
links=["link1"],
|
||||
avatar_url="avatar.jpg",
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
# Call function
|
||||
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||
category="AI'; DELETE FROM StoreAgent; --",
|
||||
featured=True,
|
||||
sorted_by="rating",
|
||||
sorted_by=db.StoreAgentsSortOptions.RATING,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
@@ -57,12 +57,6 @@ class StoreError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class AgentNotFoundError(NotFoundError):
|
||||
"""Raised when an agent is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreatorNotFoundError(NotFoundError):
|
||||
"""Raised when a creator is not found"""
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
@@ -582,7 +582,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId", uce.embedding
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
sa."agentGraphId",
|
||||
sa.graph_id,
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
@@ -627,9 +627,9 @@ async def hybrid_search(
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
ON c."storeListingVersionId" = sa.listing_version_id
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId"
|
||||
ON sa.listing_version_id = uce."contentId"
|
||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_vals AS (
|
||||
@@ -665,7 +665,7 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
"agentGraphId",
|
||||
graph_id,
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import datetime
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Self
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import prisma.models
|
||||
|
||||
|
||||
class ChangelogEntry(pydantic.BaseModel):
|
||||
version: str
|
||||
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
|
||||
date: datetime.datetime
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
class MyUnpublishedAgent(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_name: str
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyAgent]
|
||||
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyUnpublishedAgent]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
|
||||
rating: float
|
||||
agent_graph_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
|
||||
return cls(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
agent_graph_id=agent.graph_id,
|
||||
)
|
||||
|
||||
|
||||
class StoreAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[StoreAgent]
|
||||
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
agentGraphVersions: list[str]
|
||||
agentGraphId: str
|
||||
graph_id: str
|
||||
graph_versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
active_version_id: str
|
||||
has_approved_version: bool
|
||||
|
||||
# Optional changelog data when include_changelog=True
|
||||
changelog: list[ChangelogEntry] | None = None
|
||||
|
||||
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
is_featured: bool
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[Creator]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class CreatorDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
top_categories: list[str]
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
|
||||
return cls(
|
||||
store_listing_version_id=agent.listing_version_id,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
graph_id=agent.graph_id,
|
||||
graph_versions=agent.graph_versions,
|
||||
last_updated=agent.updated_at,
|
||||
recommended_schedule_cron=agent.recommended_schedule_cron,
|
||||
active_version_id=agent.listing_version_id,
|
||||
has_approved_version=True, # StoreAgent view only has approved agents
|
||||
)
|
||||
|
||||
|
||||
class Profile(pydantic.BaseModel):
|
||||
name: str
|
||||
"""Marketplace user profile (only attributes that the user can update)"""
|
||||
|
||||
username: str
|
||||
name: str
|
||||
description: str
|
||||
avatar_url: str | None
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class ProfileDetails(Profile):
|
||||
"""Marketplace user profile (including read-only fields)"""
|
||||
|
||||
is_featured: bool
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
|
||||
return cls(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatarUrl,
|
||||
description=profile.description,
|
||||
links=profile.links,
|
||||
is_featured=profile.isFeatured,
|
||||
)
|
||||
|
||||
|
||||
class CreatorDetails(ProfileDetails):
|
||||
"""Marketplace creator profile details, including aggregated stats"""
|
||||
|
||||
num_agents: int
|
||||
agent_runs: int
|
||||
agent_rating: float
|
||||
top_categories: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
|
||||
return cls(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
avatar_url=creator.avatar_url,
|
||||
description=creator.description,
|
||||
links=creator.links,
|
||||
is_featured=creator.is_featured,
|
||||
num_agents=creator.num_agents,
|
||||
agent_runs=creator.agent_runs,
|
||||
agent_rating=creator.agent_rating,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[CreatorDetails]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
# From StoreListing:
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
user_id: str
|
||||
slug: str
|
||||
|
||||
# From StoreListingVersion:
|
||||
listing_version_id: str
|
||||
listing_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
instructions: str | None
|
||||
categories: list[str]
|
||||
image_urls: list[str]
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
version: int | None = None # Actual version number from the database
|
||||
video_url: str | None
|
||||
agent_output_demo_url: str | None
|
||||
|
||||
submitted_at: datetime.datetime | None
|
||||
changes_summary: str | None
|
||||
status: prisma.enums.SubmissionStatus
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
# Additional fields for editing
|
||||
video_url: str | None = None
|
||||
agent_output_demo_url: str | None = None
|
||||
categories: list[str] = []
|
||||
# Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count: int = 0
|
||||
review_count: int = 0
|
||||
review_avg_rating: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
"""Construct from the StoreSubmission Prisma view."""
|
||||
return cls(
|
||||
listing_id=_sub.listing_id,
|
||||
user_id=_sub.user_id,
|
||||
slug=_sub.slug,
|
||||
listing_version_id=_sub.listing_version_id,
|
||||
listing_version=_sub.listing_version,
|
||||
graph_id=_sub.graph_id,
|
||||
graph_version=_sub.graph_version,
|
||||
name=_sub.name,
|
||||
sub_heading=_sub.sub_heading,
|
||||
description=_sub.description,
|
||||
instructions=_sub.instructions,
|
||||
categories=_sub.categories,
|
||||
image_urls=_sub.image_urls,
|
||||
video_url=_sub.video_url,
|
||||
agent_output_demo_url=_sub.agent_output_demo_url,
|
||||
submitted_at=_sub.submitted_at,
|
||||
changes_summary=_sub.changes_summary,
|
||||
status=_sub.status,
|
||||
reviewed_at=_sub.reviewed_at,
|
||||
reviewer_id=_sub.reviewer_id,
|
||||
review_comments=_sub.review_comments,
|
||||
run_count=_sub.run_count,
|
||||
review_count=_sub.review_count,
|
||||
review_avg_rating=_sub.review_avg_rating,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
"""
|
||||
Construct from the StoreListingVersion Prisma model (with StoreListing included)
|
||||
"""
|
||||
if not (_l := _lv.StoreListing):
|
||||
raise ValueError("StoreListingVersion must have included StoreListing")
|
||||
|
||||
return cls(
|
||||
listing_id=_l.id,
|
||||
user_id=_l.owningUserId,
|
||||
slug=_l.slug,
|
||||
listing_version_id=_lv.id,
|
||||
listing_version=_lv.version,
|
||||
graph_id=_lv.agentGraphId,
|
||||
graph_version=_lv.agentGraphVersion,
|
||||
name=_lv.name,
|
||||
sub_heading=_lv.subHeading,
|
||||
description=_lv.description,
|
||||
instructions=_lv.instructions,
|
||||
categories=_lv.categories,
|
||||
image_urls=_lv.imageUrls,
|
||||
video_url=_lv.videoUrl,
|
||||
agent_output_demo_url=_lv.agentOutputDemoUrl,
|
||||
submitted_at=_lv.submittedAt,
|
||||
changes_summary=_lv.changesSummary,
|
||||
status=_lv.submissionStatus,
|
||||
reviewed_at=_lv.reviewedAt,
|
||||
reviewer_id=_lv.reviewerId,
|
||||
review_comments=_lv.reviewComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreListingWithVersions(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
slug: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmission | None = None
|
||||
versions: list[StoreSubmission] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersions]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
graph_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Graph ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
graph_version: int = pydantic.Field(
|
||||
..., gt=0, description="Graph version must be greater than 0"
|
||||
)
|
||||
slug: str
|
||||
name: str
|
||||
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str | None = None
|
||||
class StoreSubmissionAdminView(StoreSubmission):
|
||||
internal_comments: str | None # Private admin notes
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_db(_sub).model_dump(),
|
||||
internal_comments=_sub.internal_comments,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_listing_version(_lv).model_dump(),
|
||||
internal_comments=_lv.internalComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
graph_id: str
|
||||
slug: str
|
||||
active_listing_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmissionAdminView | None = None
|
||||
versions: list[StoreSubmissionAdminView] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersionsAdminView]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreReview(pydantic.BaseModel):
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from . import model as store_model
|
||||
|
||||
|
||||
def test_pagination():
|
||||
pagination = store_model.Pagination(
|
||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
||||
)
|
||||
assert pagination.total_items == 100
|
||||
assert pagination.total_pages == 5
|
||||
assert pagination.current_page == 2
|
||||
assert pagination.page_size == 20
|
||||
|
||||
|
||||
def test_store_agent():
|
||||
agent = store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
assert agent.slug == "test-agent"
|
||||
assert agent.agent_name == "Test Agent"
|
||||
assert agent.runs == 50
|
||||
assert agent.rating == 4.5
|
||||
assert agent.agent_graph_id == "test-graph-id"
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
response = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.agents) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_agent_details():
|
||||
details = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
agent_output_demo="demo.mp4",
|
||||
agent_image=["image1.jpg", "image2.jpg"],
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=datetime.datetime.now(),
|
||||
)
|
||||
assert details.slug == "test-agent"
|
||||
assert len(details.agent_image) == 2
|
||||
assert len(details.categories) == 2
|
||||
assert len(details.versions) == 2
|
||||
|
||||
|
||||
def test_creator():
|
||||
creator = store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
assert creator.name == "Test Creator"
|
||||
assert creator.num_agents == 5
|
||||
|
||||
|
||||
def test_creators_response():
|
||||
response = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.creators) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_creator_details():
|
||||
details = store_model.CreatorDetails(
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
top_categories=["cat1", "cat2"],
|
||||
)
|
||||
assert details.name == "Test Creator"
|
||||
assert len(details.links) == 2
|
||||
assert details.agent_rating == 4.8
|
||||
assert len(details.top_categories) == 2
|
||||
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
assert submission.name == "Test Agent"
|
||||
assert len(submission.image_urls) == 2
|
||||
assert submission.status == prisma.enums.SubmissionStatus.PENDING
|
||||
|
||||
|
||||
def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.submissions) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_submission_request():
|
||||
request = store_model.StoreSubmissionRequest(
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
sub_heading="Test subheading",
|
||||
video_url="video.mp4",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
)
|
||||
assert request.agent_id == "agent123"
|
||||
assert request.agent_version == 1
|
||||
assert len(request.image_urls) == 2
|
||||
assert len(request.categories) == 2
|
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
import urllib.parse
|
||||
from typing import Literal
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
from fastapi import Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.data.graph
|
||||
import backend.util.json
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.ProfileDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_profile(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Get the profile details for the authenticated user."""
|
||||
profile = await store_db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
raise NotFoundError("User does not have a profile yet")
|
||||
return profile
|
||||
|
||||
|
||||
@@ -57,98 +51,17 @@ async def get_profile(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.CreatorDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def update_or_create_profile(
|
||||
profile: store_model.Profile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Update the store profile for the authenticated user.
|
||||
|
||||
Args:
|
||||
profile (Profile): The updated profile details
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
CreatorDetails: The updated profile
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Update the store profile for the authenticated user."""
|
||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||
return updated_profile
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
##############################################
|
||||
############### Search Endpoints #############
|
||||
##############################################
|
||||
@@ -158,60 +71,30 @@ async def get_agents(
|
||||
"/search",
|
||||
summary="Unified search across all content types",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.UnifiedSearchResponse,
|
||||
)
|
||||
async def unified_search(
|
||||
query: str,
|
||||
content_types: list[str] | None = fastapi.Query(
|
||||
content_types: list[prisma.enums.ContentType] | None = Query(
|
||||
default=None,
|
||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
||||
description="Content types to search. If not specified, searches all.",
|
||||
),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user_id: str | None = fastapi.Security(
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
user_id: str | None = Security(
|
||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||
),
|
||||
):
|
||||
) -> store_model.UnifiedSearchResponse:
|
||||
"""
|
||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
||||
Search across all content types (marketplace agents, blocks, documentation)
|
||||
using hybrid search.
|
||||
|
||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of results per page (default 20)
|
||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||
|
||||
Returns:
|
||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
# Convert string content types to enum
|
||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||
if content_types:
|
||||
try:
|
||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||
)
|
||||
|
||||
# Perform unified hybrid search
|
||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=content_type_enums,
|
||||
content_types=content_types,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@@ -245,22 +128,69 @@ async def unified_search(
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured agents"
|
||||
),
|
||||
creator: str | None = Query(
|
||||
default=None, description="Filter agents by creator username"
|
||||
),
|
||||
category: str | None = Query(default=None, description="Filter agents by category"),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
||||
default=None,
|
||||
description="Property to sort results by. Ignored if search_query is provided.",
|
||||
),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get a paginated list of agents from the marketplace,
|
||||
with optional filtering and sorting.
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_agent(
|
||||
async def get_agent_by_name(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
include_changelog: bool = fastapi.Query(default=False),
|
||||
):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
include_changelog: bool = Query(default=False),
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get details of a marketplace agent"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
@@ -270,76 +200,82 @@ async def get_agent(
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreReview,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_review(
|
||||
async def post_user_review_for_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
review: store_model.StoreReviewCreate,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a review for a store agent.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
agent_name: Name/slug of the agent
|
||||
review: Review details including score and optional comments
|
||||
user_id: ID of authenticated user creating the review
|
||||
|
||||
Returns:
|
||||
The created review
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreReview:
|
||||
"""Post a user review on a marketplace agent listing"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
|
||||
created_review = await store_db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
|
||||
return created_review
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_agent_by_listing_version(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""Get outline of graph belonging to a specific marketplace listing version"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph/download",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############# Creator Endpoints #############
|
||||
##############################################
|
||||
@@ -349,37 +285,19 @@ async def create_review(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
This is needed for:
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
- featured: bool - to limit the list to just featured agents
|
||||
- search_query: str - vector search based on the creators profile description.
|
||||
- sorted_by: [agent_rating, agent_runs] -
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured creators"
|
||||
),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""List or search marketplace creators"""
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
@@ -391,18 +309,12 @@ async def get_creators(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
"/creators/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_creator(
|
||||
username: str,
|
||||
):
|
||||
"""
|
||||
Get the details of a creator.
|
||||
- Creator Details Page
|
||||
"""
|
||||
async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||
"""Get details on a marketplace creator"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
@@ -414,20 +326,17 @@ async def get_creator(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/myagents",
|
||||
"/my-unpublished-agents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.MyAgentsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_my_agents(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
async def get_my_unpublished_agents(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.MyUnpublishedAgentsResponse:
|
||||
"""List the authenticated user's unpublished agents"""
|
||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
|
||||
@@ -436,28 +345,17 @@ async def get_my_agents(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=bool,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def delete_submission(
|
||||
submission_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
submission_id (str): ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> bool:
|
||||
"""Delete a marketplace listing submission"""
|
||||
result = await store_db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -465,37 +363,14 @@ async def delete_submission(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmissionsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_submissions(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreListingsResponse: Paginated list of store submissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreSubmissionsResponse:
|
||||
"""List the authenticated user's marketplace listing submissions"""
|
||||
listings = await store_db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
@@ -508,30 +383,17 @@ async def get_submissions(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_submission(
|
||||
submission_request: store_model.StoreSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a new store listing submission.
|
||||
|
||||
Args:
|
||||
submission_request (StoreSubmissionRequest): The submission details
|
||||
user_id (str): ID of the authenticated user submitting the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Submit a new marketplace listing for review"""
|
||||
result = await store_db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
graph_id=submission_request.graph_id,
|
||||
graph_version=submission_request.graph_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
@@ -544,7 +406,6 @@ async def create_submission(
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -552,28 +413,14 @@ async def create_submission(
|
||||
"/submissions/{store_listing_version_id}",
|
||||
summary="Edit store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def edit_submission(
|
||||
store_listing_version_id: str,
|
||||
submission_request: store_model.StoreSubmissionEditRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to edit
|
||||
submission_request (StoreSubmissionRequest): The updated submission details
|
||||
user_id (str): ID of the authenticated user editing the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The updated store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Update a pending marketplace listing submission"""
|
||||
result = await store_db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
@@ -588,7 +435,6 @@ async def edit_submission(
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -596,115 +442,61 @@ async def edit_submission(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def upload_submission_media(
|
||||
file: fastapi.UploadFile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Upload media (images/videos) for a store listing submission.
|
||||
|
||||
Args:
|
||||
file (UploadFile): The media file to upload
|
||||
user_id (str): ID of the authenticated user uploading the media
|
||||
|
||||
Returns:
|
||||
str: URL of the uploaded media file
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> str:
|
||||
"""Upload media for a marketplace listing submission"""
|
||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||
return media_url
|
||||
|
||||
|
||||
class ImageURLResponse(BaseModel):
|
||||
image_url: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def generate_image(
|
||||
agent_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> fastapi.responses.Response:
|
||||
graph_id: str,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> ImageURLResponse:
|
||||
"""
|
||||
Generate an image for a store listing submission.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent to generate an image for
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
Generate an image for a marketplace listing submission based on the properties
|
||||
of a given graph.
|
||||
"""
|
||||
agent = await backend.data.graph.get_graph(
|
||||
graph_id=agent_id, version=None, user_id=user_id
|
||||
graph = await backend.data.graph.get_graph(
|
||||
graph_id=graph_id, version=None, user_id=user_id
|
||||
)
|
||||
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
filename = f"agent_{graph_id}.jpeg"
|
||||
|
||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
logger.info(f"Using existing image for agent graph {graph_id}")
|
||||
return ImageURLResponse(image_url=existing_url)
|
||||
# Generate agent image as JPEG
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return ImageURLResponse(image_url=image_url)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
|
||||
from . import model as store_model
|
||||
from . import routes as store_routes
|
||||
|
||||
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
|
||||
mock_db_call.assert_called_once_with(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by="runs",
|
||||
sorted_by=StoreAgentsSortOptions.RUNS,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
@@ -380,9 +382,11 @@ def test_get_agent_details(
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=["1.0.0", "1.1.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_versions=["1", "2"],
|
||||
graph_id="test-graph-id",
|
||||
last_updated=FIXED_NOW,
|
||||
active_version_id="test-version-id",
|
||||
has_approved_version=True,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||
mock_db_call.return_value = mocked_value
|
||||
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
|
||||
) -> None:
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
store_model.CreatorDetails(
|
||||
name=f"Creator {i}",
|
||||
username=f"creator{i}",
|
||||
description=f"Creator {i} description",
|
||||
avatar_url=f"avatar{i}.jpg",
|
||||
num_agents=1,
|
||||
agent_rating=4.5,
|
||||
agent_runs=100,
|
||||
description=f"Creator {i} description",
|
||||
links=[f"user{i}.link.com"],
|
||||
is_featured=False,
|
||||
num_agents=1,
|
||||
agent_runs=100,
|
||||
agent_rating=4.5,
|
||||
top_categories=["cat1", "cat2", "cat3"],
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
@@ -496,19 +502,19 @@ def test_get_creator_details(
|
||||
mocked_value = store_model.CreatorDetails(
|
||||
name="Test User",
|
||||
username="creator1",
|
||||
avatar_url="avatar.jpg",
|
||||
description="Test creator description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
is_featured=True,
|
||||
num_agents=5,
|
||||
agent_runs=1000,
|
||||
agent_rating=4.8,
|
||||
top_categories=["category1", "category2"],
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.store.db.get_store_creator_details"
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creator/creator1")
|
||||
response = client.get("/creators/creator1")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = store_model.CreatorDetails.model_validate(response.json())
|
||||
@@ -528,19 +534,26 @@ def test_get_submissions_success(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
date_submitted=FIXED_NOW,
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
runs=50,
|
||||
rating=4.2,
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
sub_heading="Test agent subheading",
|
||||
user_id="test-user-id",
|
||||
slug="test-agent",
|
||||
video_url="test.mp4",
|
||||
listing_version_id="test-version-id",
|
||||
listing_version=1,
|
||||
graph_id="test-agent-id",
|
||||
graph_version=1,
|
||||
name="Test Agent",
|
||||
sub_heading="Test agent subheading",
|
||||
description="Test agent description",
|
||||
instructions="Click the button!",
|
||||
categories=["test-category"],
|
||||
image_urls=["test.jpg"],
|
||||
video_url="test.mp4",
|
||||
agent_output_demo_url="demo_video.mp4",
|
||||
submitted_at=FIXED_NOW,
|
||||
changes_summary="Initial Submission",
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
run_count=50,
|
||||
review_count=5,
|
||||
review_avg_rating=4.2,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from .db import StoreAgentsSortOptions
|
||||
from .model import StoreAgent, StoreAgentsResponse
|
||||
|
||||
|
||||
@@ -215,7 +216,7 @@ class TestCacheDeletion:
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -227,7 +228,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -239,7 +240,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
|
||||
@@ -55,6 +55,7 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
user = await get_or_activate_user(user_data)
|
||||
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -165,8 +163,11 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> dict[str, str]:
|
||||
await get_or_activate_user(user_data)
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
return {"email": email}
|
||||
@@ -182,7 +183,7 @@ async def get_user_timezone_route(
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
user = await get_or_create_user(user_data)
|
||||
user = await get_or_activate_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -193,9 +194,12 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
await get_or_activate_user(user_data)
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
@@ -208,7 +212,9 @@ async def update_user_timezone_route(
|
||||
)
|
||||
async def get_preferences(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> NotificationPreference:
|
||||
await get_or_activate_user(user_data)
|
||||
preferences = await get_user_notification_preference(user_id)
|
||||
return preferences
|
||||
|
||||
@@ -222,7 +228,9 @@ async def get_preferences(
|
||||
async def update_preferences(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
preferences: NotificationPreferenceDTO = Body(...),
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> NotificationPreference:
|
||||
await get_or_activate_user(user_data)
|
||||
output = await update_user_notification_preference(user_id, preferences)
|
||||
return output
|
||||
|
||||
@@ -449,7 +457,6 @@ async def execute_graph_block(
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
file: UploadFile = File(...),
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 24,
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
@@ -512,7 +519,6 @@ async def upload_file(
|
||||
storage_path = await cloud_storage.store_file(
|
||||
content=content,
|
||||
filename=file_name,
|
||||
provider=provider,
|
||||
expiration_hours=expiration_hours,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
@@ -515,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
result = await upload_file(
|
||||
file=upload_file_mock,
|
||||
user_id=test_user_id,
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
)
|
||||
|
||||
@@ -533,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
mock_handler.store_file.assert_called_once_with(
|
||||
content=file_content,
|
||||
filename="test.txt",
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -55,6 +56,7 @@ from backend.util.exceptions import (
|
||||
MissingConfigError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
@@ -275,6 +277,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
@@ -309,6 +312,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
sort_by: StoreAgentsSortOptions = SchemaField(
|
||||
description="How to sort the results", default=StoreAgentsSortOptions.RATING
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
|
||||
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
query="test",
|
||||
category="productivity",
|
||||
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
|
||||
limit=10,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
|
||||
@@ -22,6 +22,7 @@ from backend.copilot.model import (
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -62,8 +63,8 @@ async def _update_title_async(
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
@@ -176,14 +177,17 @@ async def stream_chat_completion_baseline(
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
|
||||
@@ -81,6 +81,35 @@ async def update_chat_session(
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def update_chat_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
Always filters by (session_id, user_id) so callers cannot mutate another
|
||||
user's session even when they know the session_id.
|
||||
|
||||
Args:
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
guard so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns True if a row was updated, False otherwise (session not found,
|
||||
wrong user, or — when only_if_empty — title was already set).
|
||||
"""
|
||||
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
|
||||
if only_if_empty:
|
||||
where["title"] = None
|
||||
result = await PrismaChatSession.prisma().update_many(
|
||||
where=where,
|
||||
data={"title": title, "updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
return result > 0
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
|
||||
@@ -469,8 +469,16 @@ async def upsert_chat_session(
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
@@ -685,30 +693,48 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
"""
|
||||
try:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
return False
|
||||
|
||||
# Invalidate the cache so the next access reloads from DB with the
|
||||
# updated title. This avoids a read-modify-write on the full session
|
||||
# blob, which could overwrite concurrent message updates.
|
||||
await invalidate_session_cache(session_id)
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
try:
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
191
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
191
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Centralized prompt building logic for CoPilot.
|
||||
|
||||
This module contains all prompt construction functions and constants,
|
||||
handling the distinction between:
|
||||
- SDK mode vs Baseline mode (tool documentation needs)
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
working_dir: str,
|
||||
sandbox_type: str,
|
||||
storage_system_1_name: str,
|
||||
storage_system_1_characteristics: list[str],
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
Template function handles all formatting (bullets, indentation, markdown).
|
||||
Callers provide clean data as lists of strings.
|
||||
|
||||
Args:
|
||||
working_dir: Working directory path
|
||||
sandbox_type: Description of bash_exec sandbox
|
||||
storage_system_1_name: Name of primary storage (ephemeral or cloud)
|
||||
storage_system_1_characteristics: List of characteristic descriptions
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
|
||||
|
||||
return f"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
storage_system_1_name="Ephemeral working directory",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files here are **lost between turns** — do NOT rely on them persisting",
|
||||
"Use for temporary work: running scripts, processing data, etc.",
|
||||
],
|
||||
file_move_name_1_to_2="Ephemeral → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Ephemeral",
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
storage_system_1_name="Cloud sandbox",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by all file tools AND `bash_exec` — same filesystem",
|
||||
"Full Linux environment with internet access",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files **persist across turns** within the current session",
|
||||
"Lost when the session expires (12 h inactivity)",
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
SDK mode doesn't need it since Claude gets tool schemas automatically.
|
||||
|
||||
This generates a complete list of available tools with their descriptions,
|
||||
ensuring the documentation stays in sync with the actual tool implementations.
|
||||
All workflow guidance is now embedded in individual tool descriptions.
|
||||
|
||||
Only documents tools that are available in the current environment
|
||||
(checked via tool.is_available property).
|
||||
"""
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
docs += f"- **`{name}`**: {desc}\n"
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
"""Get the supplement for SDK mode (Claude Agent SDK).
|
||||
|
||||
SDK mode does NOT include tool documentation because Claude automatically
|
||||
receives tool schemas from the SDK. Only includes technical notes about
|
||||
storage systems and execution environment.
|
||||
|
||||
Args:
|
||||
use_e2b: Whether E2B cloud sandbox is being used
|
||||
cwd: Current working directory (only used in local_storage mode)
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
if use_e2b:
|
||||
return _get_cloud_sandbox_supplement()
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
direct API doesn't automatically provide tool schemas to Claude.
|
||||
Also includes shared technical notes (but NOT SDK-specific environment details).
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
@@ -127,7 +127,6 @@ def create_security_hooks(
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_compact: Callable[[], None] | None = None,
|
||||
on_stop: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
@@ -136,15 +135,12 @@ def create_security_hooks(
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||
the SDK finishes processing — used to read the JSONL transcript
|
||||
before the CLI process exits.
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
@@ -311,30 +307,6 @@ def create_security_hooks(
|
||||
on_compact()
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
# --- Stop hook: capture transcript path for stateless resume ---
|
||||
async def stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Capture transcript path when SDK finishes processing.
|
||||
|
||||
The Stop hook fires while the CLI process is still alive, giving us
|
||||
a reliable window to read the JSONL transcript before SIGTERM.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||
|
||||
if transcript_path and on_stop:
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||
)
|
||||
on_stop(transcript_path, sdk_session_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
@@ -344,9 +316,6 @@ def create_security_hooks(
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
|
||||
if on_stop is not None:
|
||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||
|
||||
return hooks
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
|
||||
@@ -12,7 +12,6 @@ import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import openai
|
||||
@@ -21,6 +20,9 @@ from claude_agent_sdk import (
|
||||
ClaudeAgentOptions,
|
||||
ClaudeSDKClient,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from langfuse import propagate_attributes
|
||||
@@ -42,6 +44,7 @@ from ..model import (
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..prompting import get_sdk_supplement
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -74,11 +77,11 @@ from .tool_adapter import (
|
||||
from .transcript import (
|
||||
cleanup_cli_project_dir,
|
||||
download_transcript,
|
||||
read_transcript_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -137,19 +140,6 @@ _setup_langfuse_otel()
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapturedTranscript:
|
||||
"""Info captured by the SDK Stop hook for stateless --resume."""
|
||||
|
||||
path: str = ""
|
||||
sdk_session_id: str = ""
|
||||
raw_content: str = ""
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return bool(self.path)
|
||||
|
||||
|
||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
|
||||
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
|
||||
@@ -157,140 +147,6 @@ _SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
|
||||
|
||||
# Appended to the system prompt to inform the agent about available tools.
|
||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
||||
# which has kernel-level network isolation (unshare --net).
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Long-running tools
|
||||
Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Large tool outputs
|
||||
When a tool output exceeds the display limit, it is automatically saved to
|
||||
the persistent workspace. The truncated output includes a
|
||||
`<tool-output-truncated>` tag with the workspace path. Use
|
||||
`read_workspace_file(path="...", offset=N, length=50000)` to retrieve
|
||||
additional sections.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
_LOCAL_TOOL_SUPPLEMENT = (
|
||||
"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a network-isolated sandbox.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{cwd}`
|
||||
- All SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec` operate inside this
|
||||
directory. This is the ONLY writable path — do not attempt to read or write
|
||||
anywhere else on the filesystem.
|
||||
- Use relative paths or absolute paths under `{cwd}` for all file operations.
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **Ephemeral working directory** (`{cwd}`):
|
||||
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
|
||||
- Files here are **lost between turns** — do NOT rely on them persisting
|
||||
- Use for temporary work: running scripts, processing data, etc.
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across turns and sessions**
|
||||
- Use `write_workspace_file` to save important files (code, outputs, configs)
|
||||
- Use `read_workspace_file` to retrieve previously saved files
|
||||
- Use `list_workspace_files` to see what files you've saved before
|
||||
- Call `list_workspace_files(include_all_sessions=True)` to see files from
|
||||
all sessions
|
||||
|
||||
### Moving files between ephemeral and persistent storage
|
||||
- **Ephemeral → Persistent**: Use `write_workspace_file` with either:
|
||||
- `content` param (plain text) — for text files
|
||||
- `source_path` param — to copy any file directly from the ephemeral dir
|
||||
- **Persistent → Ephemeral**: Use `read_workspace_file` with `save_to_path`
|
||||
param to download a workspace file to the ephemeral dir for processing
|
||||
|
||||
### File persistence workflow
|
||||
When you create or modify important files (code, configs, outputs), you MUST:
|
||||
1. Save them using `write_workspace_file` so they persist
|
||||
2. At the start of a new turn, call `list_workspace_files` to see what files
|
||||
are available from previous turns
|
||||
"""
|
||||
+ _SHARED_TOOL_NOTES
|
||||
)
|
||||
|
||||
|
||||
_E2B_TOOL_SUPPLEMENT = (
|
||||
"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a cloud sandbox with full internet access.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `/home/user` (cloud sandbox)
|
||||
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
|
||||
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
|
||||
- Files created by `bash_exec` are immediately visible to `read_file` and
|
||||
vice-versa — they share one filesystem.
|
||||
- Use relative paths (resolved from `/home/user`) or absolute paths.
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **Cloud sandbox** (`/home/user`):
|
||||
- Shared by all file tools AND `bash_exec` — same filesystem
|
||||
- Files **persist across turns** within the current session
|
||||
- Full Linux environment with internet access
|
||||
- Lost when the session expires (12 h inactivity)
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
- Use `write_workspace_file` to save important files permanently
|
||||
- Use `read_workspace_file` to retrieve previously saved files
|
||||
- Use `list_workspace_files` to see what files you've saved before
|
||||
- Call `list_workspace_files(include_all_sessions=True)` to see files from
|
||||
all sessions
|
||||
|
||||
### Moving files between sandbox and persistent storage
|
||||
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
|
||||
to copy from the sandbox to permanent storage
|
||||
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
|
||||
to download into the sandbox for processing
|
||||
|
||||
### File persistence workflow
|
||||
Important files that must survive beyond this session should be saved with
|
||||
`write_workspace_file`. Sandbox files persist across turns but are lost
|
||||
when the session expires.
|
||||
"""
|
||||
+ _SHARED_TOOL_NOTES
|
||||
)
|
||||
|
||||
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
|
||||
|
||||
@@ -451,6 +307,50 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"""Convert SDK content blocks to transcript format.
|
||||
|
||||
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
|
||||
Unknown block types are logged and skipped.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for block in blocks or []:
|
||||
if isinstance(block, TextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
result.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
elif isinstance(block, ToolResultBlock):
|
||||
tool_result_entry: dict[str, Any] = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.tool_use_id,
|
||||
"content": block.content,
|
||||
}
|
||||
if block.is_error:
|
||||
tool_result_entry["is_error"] = True
|
||||
result.append(tool_result_entry)
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
result.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": block.thinking,
|
||||
"signature": block.signature,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Unknown content block type: {type(block).__name__}. "
|
||||
f"This may indicate a new SDK version with additional block types."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _compress_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> tuple[list[ChatMessage], bool]:
|
||||
@@ -806,6 +706,11 @@ async def stream_chat_completion_sdk(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
# Structured log prefix: [SDK][<session>][T<turn>]
|
||||
# Turn = number of user messages (1-based), computed AFTER appending the new message.
|
||||
turn = sum(1 for m in session.messages if m.role == "user")
|
||||
log_prefix = f"[SDK][{session_id[:12]}][T{turn}]"
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
@@ -823,10 +728,11 @@ async def stream_chat_completion_sdk(
|
||||
message_id = str(uuid.uuid4())
|
||||
stream_id = str(uuid.uuid4())
|
||||
stream_completed = False
|
||||
ended_with_stream_error = False
|
||||
e2b_sandbox = None
|
||||
use_resume = False
|
||||
resume_file: str | None = None
|
||||
captured_transcript = CapturedTranscript()
|
||||
transcript_builder = TranscriptBuilder()
|
||||
sdk_cwd = ""
|
||||
|
||||
# Acquire stream lock to prevent concurrent streams to the same session
|
||||
@@ -841,7 +747,7 @@ async def stream_chat_completion_sdk(
|
||||
if lock_owner != stream_id:
|
||||
# Another stream is active
|
||||
logger.warning(
|
||||
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
|
||||
f"{log_prefix} Session already has an active stream: {lock_owner}"
|
||||
)
|
||||
yield StreamError(
|
||||
errorText="Another stream is already active for this session. "
|
||||
@@ -865,7 +771,7 @@ async def stream_chat_completion_sdk(
|
||||
sdk_cwd = _make_sdk_cwd(session_id)
|
||||
os.makedirs(sdk_cwd, exist_ok=True)
|
||||
except (ValueError, OSError) as e:
|
||||
logger.error("[SDK] [%s] Invalid SDK cwd: %s", session_id[:12], e)
|
||||
logger.error("%s Invalid SDK cwd: %s", log_prefix, e)
|
||||
yield StreamError(
|
||||
errorText="Unable to initialize working directory.",
|
||||
code="sdk_cwd_error",
|
||||
@@ -909,12 +815,13 @@ async def stream_chat_completion_sdk(
|
||||
):
|
||||
return None
|
||||
try:
|
||||
return await download_transcript(user_id, session_id)
|
||||
return await download_transcript(
|
||||
user_id, session_id, log_prefix=log_prefix
|
||||
)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Transcript download failed, continuing without "
|
||||
"--resume: %s",
|
||||
session_id[:12],
|
||||
"%s Transcript download failed, continuing without " "--resume: %s",
|
||||
log_prefix,
|
||||
transcript_err,
|
||||
)
|
||||
return None
|
||||
@@ -926,21 +833,27 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
use_e2b = e2b_sandbox is not None
|
||||
system_prompt = base_system_prompt + (
|
||||
_E2B_TOOL_SUPPLEMENT
|
||||
if use_e2b
|
||||
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
|
||||
# Append appropriate supplement (Claude gets tool schemas automatically)
|
||||
system_prompt = base_system_prompt + get_sdk_supplement(
|
||||
use_e2b=use_e2b, cwd=sdk_cwd
|
||||
)
|
||||
|
||||
# Process transcript download result
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
is_valid = validate_transcript(dl.content)
|
||||
dl_lines = dl.content.strip().split("\n") if dl.content else []
|
||||
logger.info(
|
||||
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
|
||||
log_prefix,
|
||||
len(dl.content),
|
||||
len(dl_lines),
|
||||
dl.message_count,
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
# Load previous FULL context into builder
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
@@ -948,16 +861,14 @@ async def stream_chat_completion_sdk(
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"{log_prefix} Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"{log_prefix} No transcript available "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
@@ -979,25 +890,6 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
# --- Transcript capture via Stop hook ---
|
||||
# Read the file content immediately — the SDK may clean up
|
||||
# the file before our finally block runs.
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
content = read_transcript_file(transcript_path)
|
||||
if content:
|
||||
captured_transcript.raw_content = content
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: captured {len(content)}B from "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Stop hook: transcript file empty/missing at "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
|
||||
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
|
||||
compaction = CompactionTracker()
|
||||
|
||||
@@ -1005,7 +897,6 @@ async def stream_chat_completion_sdk(
|
||||
user_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
max_subtasks=config.claude_agent_max_subtasks,
|
||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
@@ -1040,7 +931,10 @@ async def stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
trace_name="copilot-sdk",
|
||||
tags=["sdk"],
|
||||
metadata={"resume": str(use_resume)},
|
||||
metadata={
|
||||
"resume": str(use_resume),
|
||||
"conversation_turn": str(turn),
|
||||
},
|
||||
)
|
||||
_otel_ctx.__enter__()
|
||||
|
||||
@@ -1074,9 +968,9 @@ async def stream_chat_completion_sdk(
|
||||
query_message = f"{query_message}\n\n{attachments.hint}"
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, "
|
||||
"%s Sending query — resume=%s, total_msgs=%d, "
|
||||
"query_len=%d, attached_files=%d, image_blocks=%d",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
len(query_message),
|
||||
@@ -1105,15 +999,19 @@ async def stream_chat_completion_sdk(
|
||||
await client._transport.write( # noqa: SLF001
|
||||
json.dumps(user_msg) + "\n"
|
||||
)
|
||||
# Capture user message in transcript (multimodal)
|
||||
transcript_builder.append_user(content=content_blocks)
|
||||
else:
|
||||
await client.query(query_message, session_id=session_id)
|
||||
# Capture actual user message in transcript (not the engineered query)
|
||||
# query_message may include context wrappers, but transcript needs raw input
|
||||
transcript_builder.append_user(content=current_message)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
ended_with_stream_error = False
|
||||
|
||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
|
||||
@@ -1150,8 +1048,8 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg = done.pop().result()
|
||||
except StopAsyncIteration:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
|
||||
session_id[:12],
|
||||
"%s Stream ended normally (StopAsyncIteration)",
|
||||
log_prefix,
|
||||
)
|
||||
break
|
||||
except Exception as stream_err:
|
||||
@@ -1160,8 +1058,8 @@ async def stream_chat_completion_sdk(
|
||||
# so the session can still be saved and the
|
||||
# frontend gets a clean finish.
|
||||
logger.error(
|
||||
"[SDK] [%s] Stream error from SDK: %s",
|
||||
session_id[:12],
|
||||
"%s Stream error from SDK: %s",
|
||||
log_prefix,
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -1173,9 +1071,9 @@ async def stream_chat_completion_sdk(
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: %s %s "
|
||||
"%s Received: %s %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
len(adapter.current_tool_calls)
|
||||
@@ -1210,10 +1108,10 @@ async def stream_chat_completion_sdk(
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Timed out waiting for "
|
||||
"%s Timed out waiting for "
|
||||
"PostToolUse hook stash "
|
||||
"(%d unresolved tool calls)",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
@@ -1221,9 +1119,9 @@ async def stream_chat_completion_sdk(
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: ResultMessage %s "
|
||||
"%s Received: ResultMessage %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
sdk_msg.subtype,
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
@@ -1232,8 +1130,8 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
if sdk_msg.subtype in ("error", "error_during_execution"):
|
||||
logger.error(
|
||||
"[SDK] [%s] SDK execution failed with error: %s",
|
||||
session_id[:12],
|
||||
"%s SDK execution failed with error: %s",
|
||||
log_prefix,
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
@@ -1258,8 +1156,8 @@ async def stream_chat_completion_sdk(
|
||||
out_len = len(str(response.output))
|
||||
extra = f", output_len={out_len}"
|
||||
logger.info(
|
||||
"[SDK] [%s] Tool event: %s, tool=%s%s",
|
||||
session_id[:12],
|
||||
"%s Tool event: %s, tool=%s%s",
|
||||
log_prefix,
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
extra,
|
||||
@@ -1268,8 +1166,8 @@ async def stream_chat_completion_sdk(
|
||||
# Log errors being sent to frontend
|
||||
if isinstance(response, StreamError):
|
||||
logger.error(
|
||||
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
|
||||
session_id[:12],
|
||||
"%s Sending error to frontend: %s (code=%s)",
|
||||
log_prefix,
|
||||
response.errorText,
|
||||
response.code,
|
||||
)
|
||||
@@ -1314,29 +1212,44 @@ async def stream_chat_completion_sdk(
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
content = (
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else json.dumps(response.output, ensure_ascii=False)
|
||||
)
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
content=content,
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=response.toolCallId,
|
||||
content=content,
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
# Append assistant entry AFTER convert_message so that
|
||||
# any stashed tool results from the previous turn are
|
||||
# recorded first, preserving the required API order:
|
||||
# assistant(tool_use) → tool_result → assistant(text).
|
||||
if isinstance(sdk_msg, AssistantMessage):
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
|
||||
model=sdk_msg.model,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task/generator was cancelled (e.g. client disconnect,
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
# blocks handle cleanup.
|
||||
logger.warning(
|
||||
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
|
||||
session_id[:12],
|
||||
"%s Streaming loop cancelled (asyncio.CancelledError)",
|
||||
log_prefix,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1350,7 +1263,8 @@ async def stream_chat_completion_sdk(
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
# Expected: task was cancelled or exhausted during cleanup
|
||||
logger.info(
|
||||
"[SDK] Pending __anext__ task completed during cleanup"
|
||||
"%s Pending __anext__ task completed during cleanup",
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
# Safety net: if tools are still unresolved after the
|
||||
@@ -1359,9 +1273,9 @@ async def stream_chat_completion_sdk(
|
||||
# them now so the frontend stops showing spinners.
|
||||
if adapter.has_unresolved_tool_calls:
|
||||
logger.warning(
|
||||
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
|
||||
"%s %d unresolved tool(s) after stream loop — "
|
||||
"flushing as safety net",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
|
||||
)
|
||||
safety_responses: list[StreamBaseResponse] = []
|
||||
@@ -1372,11 +1286,20 @@ async def stream_chat_completion_sdk(
|
||||
(StreamToolInputAvailable, StreamToolOutputAvailable),
|
||||
):
|
||||
logger.info(
|
||||
"[SDK] [%s] Safety flush: %s, tool=%s",
|
||||
session_id[:12],
|
||||
"%s Safety flush: %s, tool=%s",
|
||||
log_prefix,
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
)
|
||||
if isinstance(response, StreamToolOutputAvailable):
|
||||
transcript_builder.append_tool_result(
|
||||
tool_use_id=response.toolCallId,
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else json.dumps(response.output, ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
yield response
|
||||
|
||||
# If the stream ended without a ResultMessage, the SDK
|
||||
@@ -1386,8 +1309,8 @@ async def stream_chat_completion_sdk(
|
||||
# StreamFinish is published by mark_session_completed in the processor.
|
||||
if not stream_completed and not ended_with_stream_error:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
|
||||
session_id[:12],
|
||||
"%s Stream ended without ResultMessage (stopped by user)",
|
||||
log_prefix,
|
||||
)
|
||||
closing_responses: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(closing_responses)
|
||||
@@ -1408,69 +1331,36 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# After async with the SDK task group has exited, so the Stop
|
||||
# hook has already fired and the CLI has been SIGTERMed. The
|
||||
# CLI uses appendFileSync, so all writes are safely on disk.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# With --resume the CLI appends to the resume file (most
|
||||
# complete). Otherwise use the Stop hook path.
|
||||
if use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
logger.debug("[SDK] Transcript source: resume file")
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
logger.debug(
|
||||
"[SDK] Transcript source: stop hook (%s), read result: %s",
|
||||
captured_transcript.path,
|
||||
f"{len(raw_transcript)}B" if raw_transcript else "None",
|
||||
)
|
||||
else:
|
||||
raw_transcript = None
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
# to avoid double-uploads (the success path used to upload the
|
||||
# old resume file, then the finally block overwrote it with the
|
||||
# stop hook content — which could be smaller after compaction).
|
||||
|
||||
if not raw_transcript:
|
||||
logger.debug(
|
||||
"[SDK] No usable transcript — CLI file had no "
|
||||
"conversation entries (expected for first turn "
|
||||
"without --resume)"
|
||||
)
|
||||
|
||||
if raw_transcript:
|
||||
# Shield the upload from generator cancellation so a
|
||||
# client disconnect / page refresh doesn't lose the
|
||||
# transcript. The upload must finish even if the SSE
|
||||
# connection is torn down.
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream completed successfully with %d messages",
|
||||
session_id[:12],
|
||||
len(session.messages),
|
||||
)
|
||||
if ended_with_stream_error:
|
||||
logger.warning(
|
||||
"%s Stream ended with SDK error after %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Stream completed successfully with %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
except BaseException as e:
|
||||
# Catch BaseException to handle both Exception and CancelledError
|
||||
# (CancelledError inherits from BaseException in Python 3.8+)
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
logger.warning("[SDK] [%s] Session cancelled", session_id[:12])
|
||||
logger.warning("%s Session cancelled", log_prefix)
|
||||
error_msg = "Operation cancelled"
|
||||
else:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
# SDK cleanup RuntimeError is expected during cancellation, log as warning
|
||||
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
|
||||
logger.warning(
|
||||
"[SDK] [%s] SDK cleanup error: %s", session_id[:12], error_msg
|
||||
)
|
||||
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
|
||||
else:
|
||||
logger.error(
|
||||
f"[SDK] [%s] Error: {error_msg}", session_id[:12], exc_info=True
|
||||
)
|
||||
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
|
||||
|
||||
# Append error marker to session (non-invasive text parsing approach)
|
||||
# The finally block will persist the session with this error marker
|
||||
@@ -1481,8 +1371,8 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
"[SDK] [%s] Appended error marker, will be persisted in finally",
|
||||
session_id[:12],
|
||||
"%s Appended error marker, will be persisted in finally",
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
# Yield StreamError for immediate feedback (only for non-cancellation errors)
|
||||
@@ -1514,47 +1404,61 @@ async def stream_chat_completion_sdk(
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
logger.info(
|
||||
"[SDK] [%s] Session persisted in finally with %d messages",
|
||||
session_id[:12],
|
||||
"%s Session persisted in finally with %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
except Exception as persist_err:
|
||||
logger.error(
|
||||
"[SDK] [%s] Failed to persist session in finally: %s",
|
||||
session_id[:12],
|
||||
"%s Failed to persist session in finally: %s",
|
||||
log_prefix,
|
||||
persist_err,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# This MUST run in finally so the transcript is uploaded even when
|
||||
# the streaming loop raises an exception. The CLI uses
|
||||
# appendFileSync, so whatever was written before the error/SIGTERM
|
||||
# is safely on disk and still useful for the next turn.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# the streaming loop raises an exception.
|
||||
# The transcript represents the COMPLETE active context (atomic).
|
||||
if config.claude_agent_use_resume and user_id and session is not None:
|
||||
try:
|
||||
# Prefer content captured in the Stop hook (read before
|
||||
# cleanup removes the file). Fall back to the resume
|
||||
# file when the stop hook didn't fire (e.g. error before
|
||||
# completion) so we don't lose the prior transcript.
|
||||
raw_transcript = captured_transcript.raw_content or None
|
||||
if not raw_transcript and use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
# Build complete transcript from captured SDK messages
|
||||
transcript_content = transcript_builder.to_jsonl()
|
||||
|
||||
if raw_transcript and session is not None:
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
if not transcript_content:
|
||||
logger.warning(
|
||||
"%s No transcript to upload (builder empty)", log_prefix
|
||||
)
|
||||
elif not validate_transcript(transcript_content):
|
||||
logger.warning(
|
||||
"%s Transcript invalid, skipping upload (entries=%d)",
|
||||
log_prefix,
|
||||
transcript_builder.entry_count,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[SDK] No transcript to upload for {session_id}")
|
||||
logger.info(
|
||||
"%s Uploading complete transcript (entries=%d, bytes=%d)",
|
||||
log_prefix,
|
||||
transcript_builder.entry_count,
|
||||
len(transcript_content),
|
||||
)
|
||||
# Shield upload from cancellation - let it complete even if
|
||||
# the finally block is interrupted. No timeout to avoid race
|
||||
# conditions where backgrounded uploads overwrite newer transcripts.
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=transcript_content,
|
||||
message_count=len(session.messages),
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
except Exception as upload_err:
|
||||
logger.error(
|
||||
f"[SDK] Transcript upload failed in finally: {upload_err}",
|
||||
"%s Transcript upload failed in finally: %s",
|
||||
log_prefix,
|
||||
upload_err,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -1565,33 +1469,6 @@ async def stream_chat_completion_sdk(
|
||||
await lock.release()
|
||||
|
||||
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
raw_content: str,
|
||||
message_count: int = 0,
|
||||
) -> bool:
|
||||
"""Strip progress entries and upload transcript (with timeout).
|
||||
|
||||
Returns True if the upload completed without error.
|
||||
"""
|
||||
try:
|
||||
async with asyncio.timeout(30):
|
||||
await upload_transcript(
|
||||
user_id, session_id, raw_content, message_count=message_count
|
||||
)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SDK] Failed to upload transcript for {session_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
@@ -1600,8 +1477,8 @@ async def _update_title_async(
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
|
||||
@@ -145,3 +145,103 @@ class TestPrepareFileAttachments:
|
||||
|
||||
assert "Read tool" not in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
|
||||
|
||||
class TestPromptSupplement:
|
||||
"""Tests for centralized prompt supplement generation."""
|
||||
|
||||
def test_sdk_supplement_excludes_tool_docs(self):
|
||||
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
|
||||
from backend.copilot.prompting import get_sdk_supplement
|
||||
|
||||
# Test both local and E2B modes
|
||||
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
|
||||
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
|
||||
|
||||
# Should NOT have tool list section
|
||||
assert "## AVAILABLE TOOLS" not in local_supplement
|
||||
assert "## AVAILABLE TOOLS" not in e2b_supplement
|
||||
|
||||
# Should still have technical notes
|
||||
assert "## Tool notes" in local_supplement
|
||||
assert "## Tool notes" in e2b_supplement
|
||||
|
||||
def test_baseline_supplement_includes_tool_docs(self):
|
||||
"""Baseline mode MUST include tool documentation (direct API needs it)."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
supplement = get_baseline_supplement()
|
||||
|
||||
# MUST have tool list section
|
||||
assert "## AVAILABLE TOOLS" in supplement
|
||||
|
||||
# Should NOT have environment-specific notes (SDK-only)
|
||||
assert "## Tool notes" not in supplement
|
||||
|
||||
def test_baseline_supplement_includes_key_tools(self):
|
||||
"""Baseline supplement should document all essential tools."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Core agent workflow tools (always available)
|
||||
assert "`create_agent`" in docs
|
||||
assert "`run_agent`" in docs
|
||||
assert "`find_library_agent`" in docs
|
||||
assert "`edit_agent`" in docs
|
||||
|
||||
# MCP integration (always available)
|
||||
assert "`run_mcp_tool`" in docs
|
||||
|
||||
# Folder management (always available)
|
||||
assert "`create_folder`" in docs
|
||||
|
||||
# Browser tools only if available (Playwright may not be installed in CI)
|
||||
if (
|
||||
TOOL_REGISTRY.get("browser_navigate")
|
||||
and TOOL_REGISTRY["browser_navigate"].is_available
|
||||
):
|
||||
assert "`browser_navigate`" in docs
|
||||
|
||||
def test_baseline_supplement_includes_workflows(self):
|
||||
"""Baseline supplement should include workflow guidance in tool descriptions."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Workflows are now in individual tool descriptions (not separate sections)
|
||||
# Check that key workflow concepts appear in tool descriptions
|
||||
assert "suggested_goal" in docs or "clarifying_questions" in docs
|
||||
assert "run_mcp_tool" in docs
|
||||
|
||||
def test_baseline_supplement_completeness(self):
|
||||
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Verify each available registered tool is documented
|
||||
# (matches _generate_tool_documentation which filters by is_available)
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
assert (
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Count occurrences of each available tool in the entire supplement
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
@@ -10,13 +10,14 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -58,41 +59,37 @@ def strip_progress_entries(content: str) -> str:
|
||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||
Typically reduces transcript size by ~30%.
|
||||
|
||||
Entries that are not stripped or reparented are kept as their original
|
||||
raw JSON line to avoid unnecessary re-serialization that changes
|
||||
whitespace or key ordering.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
entries: list[dict] = []
|
||||
# Parse entries, keeping the original line alongside the parsed dict.
|
||||
parsed: list[tuple[str, dict | None]] = []
|
||||
for line in lines:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
# Keep unparseable lines as-is (safety)
|
||||
entries.append({"_raw": line})
|
||||
parsed.append((line, json.loads(line, fallback=None)))
|
||||
|
||||
# First pass: identify stripped UUIDs and build parent map.
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
kept: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
if "_raw" in entry:
|
||||
kept.append(entry)
|
||||
for _line, entry in parsed:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
if uid:
|
||||
stripped_uuids.add(uid)
|
||||
else:
|
||||
kept.append(entry)
|
||||
|
||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
# Preserve original line when no reparenting is required.
|
||||
reparented: set[str] = set()
|
||||
for _line, entry in parsed:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
@@ -100,63 +97,32 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
uid = entry.get("uuid", "")
|
||||
if uid:
|
||||
reparented.add(uid)
|
||||
|
||||
result_lines: list[str] = []
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
result_lines.append(entry["_raw"])
|
||||
else:
|
||||
for line, entry in parsed:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
# Re-serialize only entries whose parentUuid was changed.
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
else:
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||
# Local file I/O (write temp file for --resume)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_transcript_file(transcript_path: str) -> str | None:
|
||||
"""Read a JSONL transcript file from disk.
|
||||
|
||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||
or only contains metadata (≤2 lines with no conversation messages).
|
||||
"""
|
||||
if not transcript_path or not os.path.isfile(transcript_path):
|
||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(transcript_path) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug("[Transcript] File is empty: %s", transcript_path)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
# Validate that the transcript has real conversation content
|
||||
# (not just metadata like queue-operation entries).
|
||||
if not validate_transcript(content):
|
||||
logger.debug(
|
||||
"[Transcript] No conversation content (%d lines) in %s",
|
||||
len(lines),
|
||||
transcript_path,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
"""Sanitize an ID for safe use in file paths.
|
||||
|
||||
@@ -171,14 +137,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
@@ -188,7 +146,8 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""
|
||||
import shutil
|
||||
|
||||
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
@@ -248,32 +207,29 @@ def write_transcript_to_tempfile(
|
||||
def validate_transcript(content: str | None) -> bool:
|
||||
"""Check that a transcript has actual conversation messages.
|
||||
|
||||
A valid transcript for resume needs at least one user message and one
|
||||
assistant message (not just queue-operation / file-history-snapshot
|
||||
metadata).
|
||||
A valid transcript needs at least one assistant message (not just
|
||||
queue-operation / file-history-snapshot metadata). We do NOT require
|
||||
a ``type: "user"`` entry because with ``--resume`` the user's message
|
||||
is passed as a CLI query parameter and does not appear in the
|
||||
transcript file.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
has_user = False
|
||||
has_assistant = False
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
msg_type = entry.get("type")
|
||||
if msg_type == "user":
|
||||
has_user = True
|
||||
elif msg_type == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
return False
|
||||
if entry.get("type") == "assistant":
|
||||
has_assistant = True
|
||||
|
||||
return has_user and has_assistant
|
||||
return has_assistant
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -328,26 +284,41 @@ async def upload_transcript(
|
||||
session_id: str,
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
"""Strip progress entries and upload complete transcript.
|
||||
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen. We always overwrite — with ``--resume``
|
||||
the CLI may compact old tool results, so neither byte size nor line count
|
||||
is a reliable proxy for "newer".
|
||||
the same session cannot happen.
|
||||
|
||||
Args:
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
the next turn to detect staleness and compress only the gap.
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
logger.warning(
|
||||
f"[Transcript] Skipping upload — stripped content not valid "
|
||||
f"for session {session_id}"
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
@@ -373,17 +344,18 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(
|
||||
user_id: str, session_id: str
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
|
||||
@@ -399,10 +371,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -419,16 +391,13 @@ async def download_transcript(
|
||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
||||
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"))
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)}B "
|
||||
f"(msg_count={message_count}) for session {session_id}"
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import STRIPPABLE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
|
||||
"""Single transcript entry (user or assistant turn)."""
|
||||
|
||||
type: str
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
This builder maintains the FULL conversation state, not incremental changes.
|
||||
The output is always the complete active context.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[TranscriptEntry] = []
|
||||
self._last_uuid: str | None = None
|
||||
|
||||
def _last_is_assistant(self) -> bool:
|
||||
return bool(self._entries) and self._entries[-1].type == "assistant"
|
||||
|
||||
def _last_message_id(self) -> str:
|
||||
"""Return the message.id of the last entry, or '' if none."""
|
||||
if self._entries:
|
||||
return self._entries[-1].message.get("id", "")
|
||||
return ""
|
||||
|
||||
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
This loads the FULL previous context. As new messages come in,
|
||||
we append to this state. The final output is the complete context
|
||||
(previous + new), not just the delta.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
data = json.loads(line, fallback=None)
|
||||
if data is None:
|
||||
logger.warning(
|
||||
"%s Failed to parse transcript line %d/%d",
|
||||
log_prefix,
|
||||
line_num,
|
||||
len(lines),
|
||||
)
|
||||
continue
|
||||
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
type=data["type"],
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
logger.info(
|
||||
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
log_prefix,
|
||||
len(self._entries),
|
||||
self._last_uuid[:12] if self._last_uuid else None,
|
||||
)
|
||||
|
||||
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
|
||||
"""Append a user entry."""
|
||||
msg_uuid = uuid or str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="user",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={"role": "user", "content": content},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def append_tool_result(self, tool_use_id: str, content: str) -> None:
|
||||
"""Append a tool result as a user entry (one per tool call)."""
|
||||
self.append_user(
|
||||
content=[
|
||||
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
|
||||
]
|
||||
)
|
||||
|
||||
def append_assistant(
|
||||
self,
|
||||
content_blocks: list[dict],
|
||||
model: str = "",
|
||||
stop_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Append an assistant entry.
|
||||
|
||||
Consecutive assistant entries automatically share the same message ID
|
||||
so the CLI can merge them (thinking → text → tool_use) into a single
|
||||
API message on ``--resume``. A new ID is assigned whenever an
|
||||
assistant entry follows a non-assistant entry (user message or tool
|
||||
result), because that marks the start of a new API response.
|
||||
"""
|
||||
message_id = (
|
||||
self._last_message_id()
|
||||
if self._last_is_assistant()
|
||||
else f"msg_sdk_{uuid4().hex[:24]}"
|
||||
)
|
||||
|
||||
msg_uuid = str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="assistant",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"content": content_blocks,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
Consecutive assistant entries are kept separate to match the
|
||||
native CLI format — the SDK merges them internally on resume.
|
||||
|
||||
Returns the FULL conversation state (all entries), not incremental.
|
||||
This output REPLACES any previous transcript.
|
||||
"""
|
||||
if not self._entries:
|
||||
return ""
|
||||
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
return len(self._entries)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Whether this builder has any entries."""
|
||||
return len(self._entries) == 0
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
read_transcript_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -38,49 +38,6 @@ PROGRESS_ENTRY = {
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
# --- read_transcript_file ---
|
||||
|
||||
|
||||
class TestReadTranscriptFile:
|
||||
def test_returns_content_for_valid_file(self, tmp_path):
|
||||
path = tmp_path / "session.jsonl"
|
||||
path.write_text(VALID_TRANSCRIPT)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_returns_none_for_missing_file(self):
|
||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
||||
|
||||
def test_returns_none_for_empty_path(self):
|
||||
assert read_transcript_file("") is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
path = tmp_path / "meta.jsonl"
|
||||
path.write_text(content)
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
||||
path = tmp_path / "bad.jsonl"
|
||||
path.write_text("not json\n{}\n{}\n")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_no_size_limit(self, tmp_path):
|
||||
"""Large files are accepted — bucket storage has no size limit."""
|
||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
||||
path = tmp_path / "big.jsonl"
|
||||
path.write_text(content)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
|
||||
|
||||
# --- write_transcript_to_tempfile ---
|
||||
|
||||
|
||||
@@ -155,12 +112,56 @@ class TestValidateTranscript:
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_assistant_only_no_user(self):
|
||||
"""With --resume the user message is a CLI query param, not a transcript entry.
|
||||
A transcript with only assistant entries is valid."""
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_resume_transcript_without_user_entry(self):
|
||||
"""Simulates a real --resume stop hook transcript: the CLI session file
|
||||
has summary + assistant entries but no user entry."""
|
||||
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
|
||||
asst1 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Hello!"},
|
||||
}
|
||||
asst2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Sure, let me help."},
|
||||
}
|
||||
content = _make_jsonl(summary, asst1, asst2)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_single_assistant_entry(self):
|
||||
"""A transcript with just one assistant line is valid — the CLI may
|
||||
produce short transcripts for simple responses with no tool use."""
|
||||
content = json.dumps(ASST_MSG) + "\n"
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_invalid_json_returns_false(self):
|
||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
||||
|
||||
def test_malformed_json_after_valid_assistant_returns_false(self):
|
||||
"""Validation must scan all lines - malformed JSON anywhere should fail."""
|
||||
valid_asst = json.dumps(ASST_MSG)
|
||||
malformed = "not valid json"
|
||||
content = valid_asst + "\n" + malformed + "\n"
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_blank_lines_are_skipped(self):
|
||||
"""Transcripts with blank lines should be valid if they contain assistant entries."""
|
||||
content = (
|
||||
json.dumps(USER_MSG)
|
||||
+ "\n\n" # blank line
|
||||
+ json.dumps(ASST_MSG)
|
||||
+ "\n"
|
||||
+ "\n" # another blank line
|
||||
)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
|
||||
# --- strip_progress_entries ---
|
||||
|
||||
@@ -253,3 +254,31 @@ class TestStripProgressEntries:
|
||||
assert "queue-operation" not in result_types
|
||||
assert "user" in result_types
|
||||
assert "assistant" in result_types
|
||||
|
||||
def test_preserves_original_line_formatting(self):
|
||||
"""Non-reparented entries keep their original JSON formatting."""
|
||||
# orjson produces compact JSON - test that we preserve the exact input
|
||||
# when no reparenting is needed (no re-serialization)
|
||||
original_line = json.dumps(USER_MSG)
|
||||
|
||||
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
|
||||
result = strip_progress_entries(content)
|
||||
result_lines = result.strip().split("\n")
|
||||
|
||||
# Original line should be byte-identical (not re-serialized)
|
||||
assert result_lines[0] == original_line
|
||||
|
||||
def test_reparented_entries_are_reserialized(self):
|
||||
"""Entries whose parentUuid changes must be re-serialized."""
|
||||
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p1",
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, progress, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
@@ -18,7 +18,7 @@ from langfuse.openai import (
|
||||
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import format_understanding_for_prompt
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from .config import ChatConfig
|
||||
@@ -34,8 +34,9 @@ client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
langfuse = get_client()
|
||||
|
||||
# Default system prompt used when Langfuse is not configured
|
||||
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
|
||||
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||
# Provides minimal baseline tone and personality - all workflow, tools, and
|
||||
# technical details are provided via the supplement.
|
||||
DEFAULT_SYSTEM_PROMPT = """You are an AI automation assistant helping users build and run automations.
|
||||
|
||||
Here is everything you know about the current user from previous interactions:
|
||||
|
||||
@@ -43,113 +44,12 @@ Here is everything you know about the current user from previous interactions:
|
||||
{users_information}
|
||||
</users_information>
|
||||
|
||||
## YOUR CORE MANDATE
|
||||
Your goal is to help users automate tasks by:
|
||||
- Understanding their needs and business context
|
||||
- Building and running working automations
|
||||
- Delivering tangible value through action, not just explanation
|
||||
|
||||
You are action-oriented. Your success is measured by:
|
||||
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||
- **Time Saved**: Focus on tangible efficiency gains
|
||||
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||
|
||||
## YOUR WORKFLOW
|
||||
|
||||
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||
|
||||
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||
|
||||
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||
|
||||
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||
|
||||
4. **Discover or Create Agents**:
|
||||
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||
- Search the marketplace with `find_agent` for pre-built automations
|
||||
- Find reusable components with `find_block`
|
||||
- **For live integrations** (read a GitHub repo, query a database, post to Slack, etc.) consider `run_mcp_tool` — it connects directly to external services without building a full agent
|
||||
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||
- Modify existing library agents with `edit_agent`
|
||||
- **When `create_agent` returns `suggested_goal`**: Present the suggestion to the user and ask "Would you like me to proceed with this refined goal?" If they accept, call `create_agent` again with the suggested goal.
|
||||
- **When `create_agent` returns `clarifying_questions`**: After the user answers, call `create_agent` again with the original description AND the answers in the `context` parameter.
|
||||
|
||||
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||
|
||||
6. **Show Results**: Display outputs using `agent_output`.
|
||||
|
||||
## AVAILABLE TOOLS
|
||||
|
||||
**Understanding & Discovery:**
|
||||
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
|
||||
- `search_docs`: Search platform documentation for specific technical information
|
||||
- `get_doc_page`: Retrieve full text of a specific documentation page
|
||||
|
||||
**Agent Discovery:**
|
||||
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
|
||||
- `find_agent`: Search the marketplace for pre-built automations
|
||||
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
- `create_agent`: Create a new automation agent
|
||||
- `edit_agent`: Modify an agent in the user's library
|
||||
|
||||
**Execution & Output:**
|
||||
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
|
||||
- `run_block`: Test or run a specific block independently
|
||||
- `agent_output`: View results from previous agent runs
|
||||
|
||||
**MCP (Model Context Protocol) Servers:**
|
||||
- `run_mcp_tool`: Connect to any MCP server to discover and run its tools
|
||||
|
||||
**Two-step flow:**
|
||||
1. `run_mcp_tool(server_url)` → returns a list of available tools. Each tool has `name`, `description`, and `input_schema` (JSON Schema). Read `input_schema.properties` to understand what arguments are needed.
|
||||
2. `run_mcp_tool(server_url, tool_name, tool_arguments)` → executes the tool. Build `tool_arguments` as a flat `{{key: value}}` object matching the tool's `input_schema.properties`.
|
||||
|
||||
**Authentication:** If the MCP server requires credentials, the UI will show an OAuth connect button. Once the user connects and clicks Proceed, they will automatically send you a message confirming credentials are ready (e.g. "I've connected the MCP server credentials. Please retry run_mcp_tool..."). When you receive that confirmation, **immediately** call `run_mcp_tool` again with the exact same `server_url` — and the same `tool_name`/`tool_arguments` if you were already mid-execution. Do not ask the user what to do next; just retry.
|
||||
|
||||
**Finding server URLs (fastest → slowest):**
|
||||
1. **Known hosted servers** — use directly, no lookup:
|
||||
- Notion: `https://mcp.notion.com/mcp`
|
||||
- Linear: `https://mcp.linear.app/mcp`
|
||||
- Stripe: `https://mcp.stripe.com`
|
||||
- Intercom: `https://mcp.intercom.com/mcp`
|
||||
- Cloudflare: `https://mcp.cloudflare.com/mcp`
|
||||
- Atlassian (Jira/Confluence): `https://mcp.atlassian.com/mcp`
|
||||
2. **`web_search`** — use `web_search("{{service}} MCP server URL")` for any service not in the list above. This is the fastest way to find unlisted servers.
|
||||
3. **Registry API** — `web_fetch("https://registry.modelcontextprotocol.io/v0.1/servers?search={{query}}&limit=10")` to browse what's available. Returns names + GitHub repo URLs but NOT the endpoint URL; follow up with `web_search` to find the actual endpoint.
|
||||
- **Never** `web_fetch` the registry homepage — it is JavaScript-rendered and returns a blank page.
|
||||
|
||||
**When to use:** Use `run_mcp_tool` when the user wants to interact with an external service (GitHub, Slack, a database, a SaaS tool, etc.) via its MCP integration. Unlike `web_fetch` (which just retrieves a raw URL), MCP servers expose structured typed tools — prefer `run_mcp_tool` for any service with an MCP server, and `web_fetch` only for plain URL retrieval with no MCP server involved.
|
||||
|
||||
**CRITICAL**: `run_mcp_tool` is **always available** in your tool list. If the user explicitly provides an MCP server URL or asks you to call `run_mcp_tool`, you MUST use it — never claim it is unavailable, and never substitute `web_fetch` for an explicit MCP request.
|
||||
|
||||
## BEHAVIORAL GUIDELINES
|
||||
|
||||
**Be Concise:**
|
||||
- Target 2-5 short lines maximum
|
||||
- Make every word count—no repetition or filler
|
||||
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||
|
||||
**Be Proactive:**
|
||||
- Suggest next steps before being asked
|
||||
- Anticipate needs based on conversation context and user information
|
||||
- Look for opportunities to expand scope when relevant
|
||||
- Reveal capabilities through action, not explanation
|
||||
|
||||
**Use Tools Effectively:**
|
||||
- Select the right tool for each task
|
||||
- **Always check `find_library_agent` before searching the marketplace**
|
||||
- Use `add_understanding` to capture valuable business context
|
||||
- When tool calls fail, try alternative approaches
|
||||
- **For MCP integrations**: Known URL (see list) or `web_search("{{service}} MCP server URL")` → `run_mcp_tool(server_url)` → `run_mcp_tool(server_url, tool_name, tool_arguments)`. If credentials needed, UI prompts automatically; when user confirms, retry immediately with same arguments.
|
||||
|
||||
**Handle Feedback Loops:**
|
||||
- When a tool returns a suggested alternative (like a refined goal), present it clearly and ask the user for confirmation before proceeding
|
||||
- When clarifying questions are answered, immediately re-call the tool with the accumulated context
|
||||
- Don't ask redundant questions if the user has already provided context in the conversation
|
||||
|
||||
## CRITICAL REMINDER
|
||||
|
||||
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||
Be concise, proactive, and action-oriented. Bias toward showing working solutions over lengthy explanations."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -298,6 +198,12 @@ async def assign_user_to_session(
|
||||
session = await get_chat_session(session_id, None)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"[SECURITY] Attempt to claim session {session_id} by user {user_id}, "
|
||||
f"but it already belongs to user {session.user_id}"
|
||||
)
|
||||
raise NotAuthorizedError(f"Not authorized to claim session {session_id}")
|
||||
session.user_id = user_id
|
||||
session = await upsert_chat_session(session)
|
||||
return session
|
||||
|
||||
@@ -20,6 +20,14 @@ from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
@@ -47,6 +55,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
# Folder management tools
|
||||
"create_folder": CreateFolderTool(),
|
||||
"list_folders": ListFoldersTool(),
|
||||
"update_folder": UpdateFolderTool(),
|
||||
"move_folder": MoveFolderTool(),
|
||||
"delete_folder": DeleteFolderTool(),
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
|
||||
@@ -151,8 +151,8 @@ async def setup_test_data(server):
|
||||
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Test Agent",
|
||||
description="A simple test agent",
|
||||
@@ -161,10 +161,10 @@ async def setup_test_data(server):
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
# 4. Approve the store listing version
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval",
|
||||
@@ -321,8 +321,8 @@ async def setup_llm_test_data(server):
|
||||
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="LLM Test Agent",
|
||||
description="An agent with LLM capabilities",
|
||||
@@ -330,9 +330,9 @@ async def setup_llm_test_data(server):
|
||||
categories=["testing", "ai"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for LLM agent",
|
||||
@@ -476,8 +476,8 @@ async def setup_firecrawl_test_data(server):
|
||||
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Firecrawl Test Agent",
|
||||
description="An agent with Firecrawl integration (no credentials)",
|
||||
@@ -485,9 +485,9 @@ async def setup_firecrawl_test_data(server):
|
||||
categories=["testing", "scraping"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for Firecrawl agent",
|
||||
|
||||
@@ -695,7 +695,10 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
agent_json: dict[str, Any],
|
||||
user_id: str,
|
||||
is_update: bool = False,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
@@ -703,6 +706,7 @@ async def save_agent_to_library(
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
folder_id: Optional folder ID to place the agent in
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
@@ -711,7 +715,7 @@ async def save_agent_to_library(
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
|
||||
@@ -39,9 +39,13 @@ class CreateAgentTool(BaseTool):
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
|
||||
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
|
||||
"(2) Call create_agent with description and library_agent_ids. "
|
||||
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
|
||||
"then call again with the suggested goal if accepted. "
|
||||
"(4) If response contains clarifying_questions: Present to user, collect answers, "
|
||||
"then call again with original description AND answers in the context parameter. "
|
||||
"\n\nThis feedback loop ensures the generated agent matches user intent."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -84,6 +88,14 @@ class CreateAgentTool(BaseTool):
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
@@ -105,6 +117,7 @@ class CreateAgentTool(BaseTool):
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
logger.info(
|
||||
@@ -336,7 +349,7 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
agent_json, user_id, folder_id=folder_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -80,6 +80,14 @@ class CustomizeAgentTool(BaseTool):
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
@@ -102,6 +110,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
@@ -140,7 +149,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except AgentNotFoundError:
|
||||
except NotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
@@ -310,7 +319,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False
|
||||
customized_agent, user_id, is_update=False, folder_id=folder_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
|
||||
573
autogpt_platform/backend/backend/copilot/tools/manage_folders.py
Normal file
573
autogpt_platform/backend/backend/copilot/tools/manage_folders.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""Folder management tools for the copilot."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.library.db import collect_tree_ids
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import library_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentsMovedToFolderResponse,
|
||||
ErrorResponse,
|
||||
FolderAgentSummary,
|
||||
FolderCreatedResponse,
|
||||
FolderDeletedResponse,
|
||||
FolderInfo,
|
||||
FolderListResponse,
|
||||
FolderMovedResponse,
|
||||
FolderTreeInfo,
|
||||
FolderUpdatedResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
|
||||
def _folder_to_info(
|
||||
folder: library_model.LibraryFolder,
|
||||
agents: list[FolderAgentSummary] | None = None,
|
||||
) -> FolderInfo:
|
||||
"""Convert a LibraryFolder DB model to a FolderInfo response model."""
|
||||
return FolderInfo(
|
||||
id=folder.id,
|
||||
name=folder.name,
|
||||
parent_id=folder.parent_id,
|
||||
icon=folder.icon,
|
||||
color=folder.color,
|
||||
agent_count=folder.agent_count,
|
||||
subfolder_count=folder.subfolder_count,
|
||||
agents=agents,
|
||||
)
|
||||
|
||||
|
||||
def _tree_to_info(
|
||||
tree: library_model.LibraryFolderTree,
|
||||
agents_map: dict[str, list[FolderAgentSummary]] | None = None,
|
||||
) -> FolderTreeInfo:
|
||||
"""Recursively convert a LibraryFolderTree to a FolderTreeInfo response."""
|
||||
return FolderTreeInfo(
|
||||
id=tree.id,
|
||||
name=tree.name,
|
||||
parent_id=tree.parent_id,
|
||||
icon=tree.icon,
|
||||
color=tree.color,
|
||||
agent_count=tree.agent_count,
|
||||
subfolder_count=tree.subfolder_count,
|
||||
children=[_tree_to_info(child, agents_map) for child in tree.children],
|
||||
agents=agents_map.get(tree.id) if agents_map else None,
|
||||
)
|
||||
|
||||
|
||||
def _to_agent_summaries(
|
||||
raw: list[dict[str, str | None]],
|
||||
) -> list[FolderAgentSummary]:
|
||||
"""Convert raw agent dicts to typed FolderAgentSummary models."""
|
||||
return [
|
||||
FolderAgentSummary(
|
||||
id=a["id"] or "",
|
||||
name=a["name"] or "",
|
||||
description=a["description"] or "",
|
||||
)
|
||||
for a in raw
|
||||
]
|
||||
|
||||
|
||||
def _to_agent_summaries_map(
|
||||
raw: dict[str, list[dict[str, str | None]]],
|
||||
) -> dict[str, list[FolderAgentSummary]]:
|
||||
"""Convert a folder-id-keyed dict of raw agents to typed summaries."""
|
||||
return {fid: _to_agent_summaries(agents) for fid, agents in raw.items()}
|
||||
|
||||
|
||||
class CreateFolderTool(BaseTool):
|
||||
"""Tool for creating a library folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new folder in the user's library to organize agents. "
|
||||
"Optionally nest it inside an existing folder using parent_id."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name for the new folder (max 100 chars).",
|
||||
},
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"ID of the parent folder to nest inside. "
|
||||
"Omit to create at root level."
|
||||
),
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "Optional icon identifier for the folder.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "Optional hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Create a folder with the given name and optional parent/icon/color."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
name = (kwargs.get("name") or "").strip()
|
||||
parent_id = kwargs.get("parent_id")
|
||||
icon = kwargs.get("icon")
|
||||
color = kwargs.get("color")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not name:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder name.",
|
||||
error="missing_name",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().create_folder(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
icon=icon,
|
||||
color=color,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to create folder: {e}",
|
||||
error="create_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderCreatedResponse(
|
||||
message=f"Folder '{folder.name}' created successfully!",
|
||||
folder=_folder_to_info(folder),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ListFoldersTool(BaseTool):
|
||||
"""Tool for listing library folders."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_folders"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List the user's library folders. "
|
||||
"Omit parent_id to get the full folder tree. "
|
||||
"Provide parent_id to list only direct children of that folder. "
|
||||
"Set include_agents=true to also return the agents inside each folder "
|
||||
"and root-level agents not in any folder. Always set include_agents=true "
|
||||
"when the user asks about agents, wants to see what's in their folders, "
|
||||
"or mentions agents alongside folders."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"List children of this folder. "
|
||||
"Omit to get the full folder tree."
|
||||
),
|
||||
},
|
||||
"include_agents": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to include the list of agents inside each folder. "
|
||||
"Defaults to false."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""List folders as a flat list (by parent) or full tree."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
parent_id = kwargs.get("parent_id")
|
||||
include_agents = kwargs.get("include_agents", False)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
try:
|
||||
if parent_id:
|
||||
folders = await library_db().list_folders(
|
||||
user_id=user_id, parent_id=parent_id
|
||||
)
|
||||
raw_map = (
|
||||
await library_db().get_folder_agents_map(
|
||||
user_id, [f.id for f in folders]
|
||||
)
|
||||
if include_agents
|
||||
else None
|
||||
)
|
||||
agents_map = _to_agent_summaries_map(raw_map) if raw_map else None
|
||||
return FolderListResponse(
|
||||
message=f"Found {len(folders)} folder(s).",
|
||||
folders=[
|
||||
_folder_to_info(f, agents_map.get(f.id) if agents_map else None)
|
||||
for f in folders
|
||||
],
|
||||
count=len(folders),
|
||||
session_id=session_id,
|
||||
)
|
||||
else:
|
||||
tree = await library_db().get_folder_tree(user_id=user_id)
|
||||
all_ids = collect_tree_ids(tree)
|
||||
agents_map = None
|
||||
root_agents = None
|
||||
if include_agents:
|
||||
raw_map = await library_db().get_folder_agents_map(user_id, all_ids)
|
||||
agents_map = _to_agent_summaries_map(raw_map)
|
||||
root_agents = _to_agent_summaries(
|
||||
await library_db().get_root_agent_summaries(user_id)
|
||||
)
|
||||
return FolderListResponse(
|
||||
message=f"Found {len(all_ids)} folder(s) in your library.",
|
||||
tree=[_tree_to_info(t, agents_map) for t in tree],
|
||||
root_agents=root_agents,
|
||||
count=len(all_ids),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list folders: {e}",
|
||||
error="list_folders_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class UpdateFolderTool(BaseTool):
|
||||
"""Tool for updating a folder's properties."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "update_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Update a folder's name, icon, or color."
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to update.",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "New name for the folder.",
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "New icon identifier.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "New hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Update a folder's name, icon, or color."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
name = kwargs.get("name")
|
||||
icon = kwargs.get("icon")
|
||||
color = kwargs.get("color")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().update_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
icon=icon,
|
||||
color=color,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to update folder: {e}",
|
||||
error="update_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderUpdatedResponse(
|
||||
message=f"Folder updated to '{folder.name}'.",
|
||||
folder=_folder_to_info(folder),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class MoveFolderTool(BaseTool):
|
||||
"""Tool for moving a folder to a new parent."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "move_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Move a folder to a different parent folder. "
|
||||
"Set target_parent_id to null to move to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to move.",
|
||||
},
|
||||
"target_parent_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"ID of the new parent folder. "
|
||||
"Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Move a folder to a new parent or to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
target_parent_id = kwargs.get("target_parent_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().move_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
target_parent_id=target_parent_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to move folder: {e}",
|
||||
error="move_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
dest = "a subfolder" if target_parent_id else "root level"
|
||||
return FolderMovedResponse(
|
||||
message=f"Folder '{folder.name}' moved to {dest}.",
|
||||
folder=_folder_to_info(folder),
|
||||
target_parent_id=target_parent_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteFolderTool(BaseTool):
|
||||
"""Tool for deleting a folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a folder from the user's library. "
|
||||
"Agents inside the folder are moved to root level (not deleted)."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to delete.",
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Soft-delete a folder; agents inside are moved to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await library_db().delete_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
soft_delete=True,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete folder: {e}",
|
||||
error="delete_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderDeletedResponse(
|
||||
message="Folder deleted. Any agents inside were moved to root level.",
|
||||
folder_id=folder_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class MoveAgentsToFolderTool(BaseTool):
|
||||
"""Tool for moving agents into a folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "move_agents_to_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Move one or more agents to a folder. "
|
||||
"Set folder_id to null to move agents to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of library agent IDs to move.",
|
||||
},
|
||||
"folder_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Target folder ID. Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_ids"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Move one or more agents to a folder or to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
agent_ids = kwargs.get("agent_ids", [])
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_ids:
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one agent ID.",
|
||||
error="missing_agent_ids",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
moved = await library_db().bulk_move_agents_to_folder(
|
||||
agent_ids=agent_ids,
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to move agents: {e}",
|
||||
error="move_agents_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
moved_ids = [a.id for a in moved]
|
||||
agent_names = [a.name for a in moved]
|
||||
dest = "the folder" if folder_id else "root level"
|
||||
names_str = (
|
||||
", ".join(agent_names) if agent_names else f"{len(agent_ids)} agent(s)"
|
||||
)
|
||||
return AgentsMovedToFolderResponse(
|
||||
message=f"Moved {names_str} to {dest}.",
|
||||
agent_ids=moved_ids,
|
||||
agent_names=agent_names,
|
||||
folder_id=folder_id,
|
||||
count=len(moved),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,455 @@
|
||||
"""Tests for folder management copilot tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.copilot.tools.manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from backend.copilot.tools.models import (
|
||||
AgentsMovedToFolderResponse,
|
||||
ErrorResponse,
|
||||
FolderCreatedResponse,
|
||||
FolderDeletedResponse,
|
||||
FolderListResponse,
|
||||
FolderMovedResponse,
|
||||
FolderUpdatedResponse,
|
||||
)
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-folders"
|
||||
_NOW = datetime.now(UTC)
|
||||
|
||||
|
||||
def _make_folder(
|
||||
id: str = "folder-1",
|
||||
name: str = "My Folder",
|
||||
parent_id: str | None = None,
|
||||
icon: str | None = None,
|
||||
color: str | None = None,
|
||||
agent_count: int = 0,
|
||||
subfolder_count: int = 0,
|
||||
) -> library_model.LibraryFolder:
|
||||
return library_model.LibraryFolder(
|
||||
id=id,
|
||||
user_id=_TEST_USER_ID,
|
||||
name=name,
|
||||
icon=icon,
|
||||
color=color,
|
||||
parent_id=parent_id,
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
agent_count=agent_count,
|
||||
subfolder_count=subfolder_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_tree(
|
||||
id: str = "folder-1",
|
||||
name: str = "Root",
|
||||
children: list[library_model.LibraryFolderTree] | None = None,
|
||||
) -> library_model.LibraryFolderTree:
|
||||
return library_model.LibraryFolderTree(
|
||||
id=id,
|
||||
user_id=_TEST_USER_ID,
|
||||
name=name,
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
children=children or [],
|
||||
)
|
||||
|
||||
|
||||
def _make_library_agent(id: str = "agent-1", name: str = "Test Agent"):
|
||||
agent = MagicMock()
|
||||
agent.id = id
|
||||
agent.name = name
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── CreateFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_tool():
|
||||
return CreateFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_missing_name(create_tool, session):
|
||||
result = await create_tool._execute(user_id=_TEST_USER_ID, session=session, name="")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_none_name(create_tool, session):
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name=None
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_success(create_tool, session):
|
||||
folder = _make_folder(name="New Folder")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.create_folder = AsyncMock(return_value=folder)
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name="New Folder"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderCreatedResponse)
|
||||
assert result.folder.name == "New Folder"
|
||||
assert "New Folder" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_db_error(create_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.create_folder = AsyncMock(
|
||||
side_effect=Exception("db down")
|
||||
)
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name="Folder"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "create_folder_failed"
|
||||
|
||||
|
||||
# ── ListFoldersTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def list_tool():
|
||||
return ListFoldersTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_by_parent(list_tool, session):
|
||||
folders = [_make_folder(id="f1", name="A"), _make_folder(id="f2", name="B")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.list_folders = AsyncMock(return_value=folders)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, parent_id="parent-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.count == 2
|
||||
assert len(result.folders) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree(list_tool, session):
|
||||
tree = [
|
||||
_make_tree(id="r1", name="Root", children=[_make_tree(id="c1", name="Child")])
|
||||
]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.count == 2 # root + child
|
||||
assert result.tree is not None
|
||||
assert len(result.tree) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree_with_agents_includes_root(list_tool, session):
|
||||
tree = [_make_tree(id="r1", name="Root")]
|
||||
raw_map = {"r1": [{"id": "a1", "name": "Foldered", "description": "In folder"}]}
|
||||
root_raw = [{"id": "a2", "name": "Loose Agent", "description": "At root"}]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
mock_lib.return_value.get_folder_agents_map = AsyncMock(return_value=raw_map)
|
||||
mock_lib.return_value.get_root_agent_summaries = AsyncMock(
|
||||
return_value=root_raw
|
||||
)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, include_agents=True
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.root_agents is not None
|
||||
assert len(result.root_agents) == 1
|
||||
assert result.root_agents[0].name == "Loose Agent"
|
||||
assert result.tree is not None
|
||||
assert result.tree[0].agents is not None
|
||||
assert result.tree[0].agents[0].name == "Foldered"
|
||||
mock_lib.return_value.get_root_agent_summaries.assert_awaited_once_with(
|
||||
_TEST_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree_without_agents_no_root(list_tool, session):
|
||||
tree = [_make_tree(id="r1", name="Root")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, include_agents=False
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.root_agents is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_db_error(list_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(
|
||||
side_effect=Exception("timeout")
|
||||
)
|
||||
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "list_folders_failed"
|
||||
|
||||
|
||||
# ── UpdateFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_tool():
|
||||
return UpdateFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_missing_id(update_tool, session):
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_none_id(update_tool, session):
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=None
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_success(update_tool, session):
|
||||
folder = _make_folder(name="Renamed")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.update_folder = AsyncMock(return_value=folder)
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="Renamed"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderUpdatedResponse)
|
||||
assert result.folder.name == "Renamed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_db_error(update_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.update_folder = AsyncMock(
|
||||
side_effect=Exception("not found")
|
||||
)
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="X"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "update_folder_failed"
|
||||
|
||||
|
||||
# ── MoveFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def move_tool():
|
||||
return MoveFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_missing_id(move_tool, session):
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_to_parent(move_tool, session):
|
||||
folder = _make_folder(name="Moved")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
folder_id="folder-1",
|
||||
target_parent_id="parent-1",
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderMovedResponse)
|
||||
assert "subfolder" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_to_root(move_tool, session):
|
||||
folder = _make_folder(name="Moved")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
folder_id="folder-1",
|
||||
target_parent_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderMovedResponse)
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_db_error(move_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(side_effect=Exception("circular"))
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "move_folder_failed"
|
||||
|
||||
|
||||
# ── DeleteFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def delete_tool():
|
||||
return DeleteFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_missing_id(delete_tool, session):
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_success(delete_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.delete_folder = AsyncMock(return_value=None)
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderDeletedResponse)
|
||||
assert result.folder_id == "folder-1"
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_db_error(delete_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.delete_folder = AsyncMock(
|
||||
side_effect=Exception("permission denied")
|
||||
)
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "delete_folder_failed"
|
||||
|
||||
|
||||
# ── MoveAgentsToFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def move_agents_tool():
|
||||
return MoveAgentsToFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_missing_ids(move_agents_tool, session):
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, agent_ids=[]
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_agent_ids"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_success(move_agents_tool, session):
|
||||
agents = [
|
||||
_make_library_agent(id="a1", name="Agent Alpha"),
|
||||
_make_library_agent(id="a2", name="Agent Beta"),
|
||||
]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
return_value=agents
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1", "a2"],
|
||||
folder_id="folder-1",
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentsMovedToFolderResponse)
|
||||
assert result.count == 2
|
||||
assert result.agent_names == ["Agent Alpha", "Agent Beta"]
|
||||
assert "Agent Alpha" in result.message
|
||||
assert "Agent Beta" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_to_root(move_agents_tool, session):
|
||||
agents = [_make_library_agent(id="a1", name="Agent One")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
return_value=agents
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1"],
|
||||
folder_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentsMovedToFolderResponse)
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_db_error(move_agents_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
side_effect=Exception("folder not found")
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1"],
|
||||
folder_id="bad-folder",
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "move_agents_failed"
|
||||
@@ -55,6 +55,13 @@ class ResponseType(str, Enum):
|
||||
# MCP tool types
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
# Folder management types
|
||||
FOLDER_CREATED = "folder_created"
|
||||
FOLDER_LIST = "folder_list"
|
||||
FOLDER_UPDATED = "folder_updated"
|
||||
FOLDER_MOVED = "folder_moved"
|
||||
FOLDER_DELETED = "folder_deleted"
|
||||
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -539,3 +546,82 @@ class BrowserScreenshotResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
|
||||
file_id: str # Workspace file ID — use read_workspace_file to retrieve
|
||||
filename: str
|
||||
|
||||
|
||||
# Folder management models
|
||||
|
||||
|
||||
class FolderAgentSummary(BaseModel):
|
||||
"""Lightweight agent info for folder listings."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class FolderInfo(BaseModel):
|
||||
"""Information about a folder."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
parent_id: str | None = None
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
agent_count: int = 0
|
||||
subfolder_count: int = 0
|
||||
agents: list[FolderAgentSummary] | None = None
|
||||
|
||||
|
||||
class FolderTreeInfo(FolderInfo):
|
||||
"""Folder with nested children for tree display."""
|
||||
|
||||
children: list["FolderTreeInfo"] = []
|
||||
|
||||
|
||||
class FolderCreatedResponse(ToolResponseBase):
|
||||
"""Response when a folder is created."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_CREATED
|
||||
folder: FolderInfo
|
||||
|
||||
|
||||
class FolderListResponse(ToolResponseBase):
|
||||
"""Response for listing folders."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_LIST
|
||||
folders: list[FolderInfo] = Field(default_factory=list)
|
||||
tree: list[FolderTreeInfo] | None = None
|
||||
root_agents: list[FolderAgentSummary] | None = None
|
||||
count: int = 0
|
||||
|
||||
|
||||
class FolderUpdatedResponse(ToolResponseBase):
|
||||
"""Response when a folder is updated."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_UPDATED
|
||||
folder: FolderInfo
|
||||
|
||||
|
||||
class FolderMovedResponse(ToolResponseBase):
|
||||
"""Response when a folder is moved."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_MOVED
|
||||
folder: FolderInfo
|
||||
target_parent_id: str | None = None
|
||||
|
||||
|
||||
class FolderDeletedResponse(ToolResponseBase):
|
||||
"""Response when a folder is deleted."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_DELETED
|
||||
folder_id: str
|
||||
|
||||
|
||||
class AgentsMovedToFolderResponse(ToolResponseBase):
|
||||
"""Response when agents are moved to a folder."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
|
||||
agent_ids: list[str]
|
||||
agent_names: list[str] = []
|
||||
folder_id: str | None = None
|
||||
count: int = 0
|
||||
|
||||
@@ -53,11 +53,15 @@ class RunMCPToolTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Call with just `server_url` to see available tools. "
|
||||
"Then call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"If the server requires authentication, the user will be prompted to connect it. "
|
||||
"Find MCP servers at https://registry.modelcontextprotocol.io/ — hundreds of integrations "
|
||||
"including GitHub, Postgres, Slack, filesystem, and more."
|
||||
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
|
||||
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
|
||||
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
|
||||
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
|
||||
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
|
||||
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
|
||||
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
|
||||
"Once connected and user confirms, retry the same call immediately."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -4,11 +4,20 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
bulk_move_agents_to_folder,
|
||||
create_folder,
|
||||
create_graph_in_library,
|
||||
create_library_agent,
|
||||
delete_folder,
|
||||
get_folder_agents_map,
|
||||
get_folder_tree,
|
||||
get_library_agent,
|
||||
get_library_agent_by_graph_id,
|
||||
get_root_agent_summaries,
|
||||
list_folders,
|
||||
list_library_agents,
|
||||
move_folder,
|
||||
update_folder,
|
||||
update_graph_in_library,
|
||||
)
|
||||
from backend.api.features.store.db import (
|
||||
@@ -260,6 +269,16 @@ class DatabaseManager(AppService):
|
||||
update_graph_in_library = _(update_graph_in_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
create_folder = _(create_folder)
|
||||
list_folders = _(list_folders)
|
||||
get_folder_tree = _(get_folder_tree)
|
||||
update_folder = _(update_folder)
|
||||
move_folder = _(move_folder)
|
||||
delete_folder = _(delete_folder)
|
||||
bulk_move_agents_to_folder = _(bulk_move_agents_to_folder)
|
||||
get_folder_agents_map = _(get_folder_agents_map)
|
||||
get_root_agent_summaries = _(get_root_agent_summaries)
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
@@ -305,6 +324,7 @@ class DatabaseManager(AppService):
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -433,6 +453,17 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_graph_in_library = d.update_graph_in_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# ============ Library Folders ============ #
|
||||
create_folder = d.create_folder
|
||||
list_folders = d.list_folders
|
||||
get_folder_tree = d.get_folder_tree
|
||||
update_folder = d.update_folder
|
||||
move_folder = d.move_folder
|
||||
delete_folder = d.delete_folder
|
||||
bulk_move_agents_to_folder = d.bulk_move_agents_to_folder
|
||||
get_folder_agents_map = d.get_folder_agents_map
|
||||
get_root_agent_summaries = d.get_root_agent_summaries
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
@@ -475,3 +506,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
|
||||
@@ -344,7 +344,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
),
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
"payload": exec.input_data.get("payload")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
import prisma
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
@@ -250,8 +251,8 @@ async def test_clean_graph(server: SpinTestServer):
|
||||
"_test_id": "node_with_secrets",
|
||||
"input": "normal_value",
|
||||
"control_test_input": "should be preserved",
|
||||
"api_key": "secret_api_key_123", # Should be filtered
|
||||
"password": "secret_password_456", # Should be filtered
|
||||
"api_key": "secret_api_key_123", # Should be filtered # pragma: allowlist secret # noqa
|
||||
"password": "secret_password_456", # Should be filtered # pragma: allowlist secret # noqa
|
||||
"token": "secret_token_789", # Should be filtered
|
||||
"credentials": { # Should be filtered
|
||||
"id": "fake-github-credentials-id",
|
||||
@@ -354,9 +355,24 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
# Ensure the default user has a Profile (required for store submissions)
|
||||
existing_profile = await prisma.models.Profile.prisma().find_first(
|
||||
where={"userId": DEFAULT_USER_ID}
|
||||
)
|
||||
if not existing_profile:
|
||||
await prisma.models.Profile.prisma().create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
name="Default User",
|
||||
username=f"default-user-{DEFAULT_USER_ID[:8]}",
|
||||
description="Default test user profile",
|
||||
links=[],
|
||||
)
|
||||
)
|
||||
|
||||
store_submission_request = store.StoreSubmissionRequest(
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=created_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
@@ -385,8 +401,8 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
store_listing.listing_version_id
|
||||
if store_listing.listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@@ -184,17 +184,17 @@ async def find_webhook_by_credentials_and_props(
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: Optional[list[str]],
|
||||
events: list[str] | None = None,
|
||||
) -> Webhook | None:
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
**({"events": {"has_every": events}} if events else {}),
|
||||
},
|
||||
)
|
||||
where: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
}
|
||||
if events is not None:
|
||||
where["events"] = {"has_every": events}
|
||||
webhook = await IntegrationWebhook.prisma().find_first(where=where)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
|
||||
601
autogpt_platform/backend/backend/data/invited_user.py
Normal file
601
autogpt_platform/backend/backend/data/invited_user.py
Normal file
@@ -0,0 +1,601 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.model import User
|
||||
from backend.data.tally import get_business_understanding_input_from_tally
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_tally_seed_tasks: set[asyncio.Task] = set()
|
||||
_email_adapter = TypeAdapter(EmailStr)
|
||||
|
||||
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
||||
MAX_BULK_INVITE_ROWS = 500
|
||||
|
||||
|
||||
class InvitedUserRecord(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
created_at=invited_user.createdAt,
|
||||
updated_at=invited_user.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class BulkInvitedUserRowResult(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserRecord] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResult(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResult]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ParsedInviteRow:
|
||||
row_number: int
|
||||
email: str
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
||||
if preferred_name:
|
||||
return preferred_name
|
||||
local_part = email.split("@", 1)[0].strip()
|
||||
return local_part or "user"
|
||||
|
||||
|
||||
def _sanitize_username_base(email: str) -> str:
|
||||
local_part = email.split("@", 1)[0].lower()
|
||||
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
||||
sanitized = sanitized.strip("-")
|
||||
return sanitized[:40] or "user"
|
||||
|
||||
|
||||
async def _generate_unique_profile_username(email: str, tx) -> str:
|
||||
base = _sanitize_username_base(email)
|
||||
|
||||
for attempt in range(10):
|
||||
candidate = base if attempt == 0 else f"{base}-{uuid4().hex[:6]}"
|
||||
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"username": candidate}
|
||||
)
|
||||
if existing is None:
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(f"Unable to generate unique username for {email}")
|
||||
|
||||
|
||||
async def _ensure_default_profile(
|
||||
user_id: str,
|
||||
email: str,
|
||||
preferred_name: Optional[str],
|
||||
tx,
|
||||
) -> None:
|
||||
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing_profile is not None:
|
||||
return
|
||||
|
||||
username = await _generate_unique_profile_username(email, tx)
|
||||
await prisma.models.Profile.prisma(tx).create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=user_id,
|
||||
name=_default_profile_name(email, preferred_name),
|
||||
username=username,
|
||||
description="I'm new here",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
||||
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _apply_tally_understanding(
|
||||
user_id: str,
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(payload)},
|
||||
"update": {"data": SafeJson(payload)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users() -> list[InvitedUserRecord]:
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
order={"createdAt": "desc"}
|
||||
)
|
||||
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
normalized_email = normalize_email(email)
|
||||
normalized_name = _normalize_name(name)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_user is not None:
|
||||
raise PreconditionFailed("An active user with this email already exists")
|
||||
|
||||
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_invited_user is not None:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().create(
|
||||
data={
|
||||
"email": normalized_email,
|
||||
"name": normalized_name,
|
||||
"status": prisma.enums.InvitedUserStatus.INVITED,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
}
|
||||
)
|
||||
schedule_invited_user_tally_precompute(invited_user.id)
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
|
||||
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
||||
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
||||
)
|
||||
if revoked_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
return InvitedUserRecord.from_db(revoked_user)
|
||||
|
||||
|
||||
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
||||
|
||||
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": None,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
"tallyComputedAt": None,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
if refreshed_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
schedule_invited_user_tally_precompute(invited_user_id)
|
||||
return InvitedUserRecord.from_db(refreshed_user)
|
||||
|
||||
|
||||
def _decode_bulk_invite_file(content: bytes) -> str:
|
||||
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
||||
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
||||
|
||||
try:
|
||||
return content.decode("utf-8-sig")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
||||
|
||||
|
||||
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
||||
indexed_rows: list[tuple[int, list[str]]] = []
|
||||
|
||||
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
||||
normalized_row = [cell.strip() for cell in row]
|
||||
if any(normalized_row):
|
||||
indexed_rows.append((row_number, normalized_row))
|
||||
|
||||
if not indexed_rows:
|
||||
return []
|
||||
|
||||
header = [cell.lower() for cell in indexed_rows[0][1]]
|
||||
has_header = "email" in header
|
||||
email_index = header.index("email") if has_header else 0
|
||||
name_index = header.index("name") if has_header and "name" in header else 1
|
||||
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
||||
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
for row_number, row in data_rows:
|
||||
email = row[email_index].strip() if len(row) > email_index else ""
|
||||
name = row[name_index].strip() if len(row) > name_index else ""
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=email,
|
||||
name=name or None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
|
||||
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=line,
|
||||
name=None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> list[_ParsedInviteRow]:
|
||||
text = _decode_bulk_invite_file(content)
|
||||
file_name = filename.lower() if filename else ""
|
||||
parsed_rows = (
|
||||
_parse_bulk_invite_csv(text)
|
||||
if file_name.endswith(".csv")
|
||||
else _parse_bulk_invite_text(text)
|
||||
)
|
||||
|
||||
if not parsed_rows:
|
||||
raise ValueError("Invite file did not contain any emails")
|
||||
|
||||
if len(parsed_rows) > MAX_BULK_INVITE_ROWS:
|
||||
raise ValueError(
|
||||
f"Invite file contains too many rows. Maximum supported rows: {MAX_BULK_INVITE_ROWS}"
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
async def bulk_create_invited_users_from_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> BulkInvitedUsersResult:
|
||||
parsed_rows = _parse_bulk_invite_file(filename, content)
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
results: list[BulkInvitedUserRowResult] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for row in parsed_rows:
|
||||
row_name = _normalize_name(row.name)
|
||||
|
||||
try:
|
||||
validated_email = _email_adapter.validate_python(row.email)
|
||||
except ValidationError:
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=row.email or None,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Invalid email address",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_email = normalize_email(str(validated_email))
|
||||
if normalized_email in seen_emails:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message="Duplicate email in upload file",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
seen_emails.add(normalized_email)
|
||||
|
||||
try:
|
||||
invited_user = await create_invited_user(normalized_email, row_name)
|
||||
except PreconditionFailed as exc:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create bulk invite for %s from row %s",
|
||||
normalized_email,
|
||||
row.row_number,
|
||||
)
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Unexpected error creating invite",
|
||||
)
|
||||
)
|
||||
else:
|
||||
created_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=invited_user,
|
||||
)
|
||||
)
|
||||
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=created_count,
|
||||
skipped_count=skipped_count,
|
||||
error_count=error_count,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
return
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return
|
||||
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
input_data = await get_business_understanding_input_from_tally(
|
||||
invited_user.email,
|
||||
require_api_key=True,
|
||||
)
|
||||
payload = (
|
||||
SafeJson(input_data.model_dump(exclude_none=True))
|
||||
if input_data is not None
|
||||
else None
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": payload,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
||||
"tallyComputedAt": datetime.now(timezone.utc),
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to compute Tally understanding for invited user %s",
|
||||
invited_user_id,
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
||||
"tallyError": str(exc),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
||||
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
||||
_tally_seed_tasks.add(task)
|
||||
task.add_done_callback(_tally_seed_tasks.discard)
|
||||
|
||||
|
||||
async def get_or_activate_user(user_data: dict) -> User:
|
||||
auth_user_id = user_data.get("sub")
|
||||
if not auth_user_id:
|
||||
raise NotAuthorizedError("User ID not found in token")
|
||||
|
||||
auth_email = user_data.get("email")
|
||||
if not auth_email:
|
||||
raise NotAuthorizedError("Email not found in token")
|
||||
|
||||
normalized_email = normalize_email(auth_email)
|
||||
user_metadata = user_data.get("user_metadata")
|
||||
metadata_name = (
|
||||
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
||||
)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if existing_user is not None:
|
||||
return User.from_db(existing_user)
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
async with transaction() as tx:
|
||||
current_user = await prisma.models.User.prisma(tx).find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if current_user is not None:
|
||||
return User.from_db(current_user)
|
||||
|
||||
current_invited_user = await prisma.models.InvitedUser.prisma(tx).find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if current_invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
if current_invited_user.authUserId not in (None, auth_user_id):
|
||||
raise NotAuthorizedError("Your invitation has already been claimed")
|
||||
|
||||
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
||||
await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma(tx).update(
|
||||
where={"id": current_invited_user.id},
|
||||
data={
|
||||
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
||||
"authUserId": auth_user_id,
|
||||
},
|
||||
)
|
||||
|
||||
await _ensure_default_profile(
|
||||
auth_user_id,
|
||||
normalized_email,
|
||||
preferred_name,
|
||||
tx,
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
||||
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
|
||||
get_user_by_id.cache_delete(auth_user_id)
|
||||
get_user_by_email.cache_delete(normalized_email)
|
||||
|
||||
activated_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if activated_user is None:
|
||||
raise RuntimeError(
|
||||
f"Activated user {auth_user_id} was not found after creation"
|
||||
)
|
||||
|
||||
return User.from_db(activated_user)
|
||||
297
autogpt_platform/backend/backend/data/invited_user_test.py
Normal file
297
autogpt_platform/backend/backend/data/invited_user_test.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
|
||||
|
||||
from .invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
get_or_activate_user,
|
||||
retry_invited_user_tally,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_db_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=status,
|
||||
authUserId=None,
|
||||
name="Invited User",
|
||||
tallyUnderstanding=tally_understanding,
|
||||
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tallyComputedAt=None,
|
||||
tallyError=None,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
|
||||
|
||||
def _user_db_record():
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="auth-user-1",
|
||||
email="invited@example.com",
|
||||
emailVerified=True,
|
||||
name="Invited User",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
metadata={},
|
||||
integrations="",
|
||||
stripeCustomerId=None,
|
||||
topUpConfig=None,
|
||||
maxEmailsPerDay=3,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="not-set",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_rejects_existing_active_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await create_invited_user("Invited@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_schedules_tally_seed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await create_invited_user("Invited@example.com", "Invited User")
|
||||
|
||||
assert invited_user.email == "invited@example.com"
|
||||
invited_user_repo.create.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_invited_user_tally_resets_state_and_schedules(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
|
||||
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await retry_invited_user_tally("invite-1")
|
||||
|
||||
assert invited_user.id == "invite-1"
|
||||
invited_user_repo.update.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_requires_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await get_or_activate_user(
|
||||
{"sub": "auth-user-1", "email": "invited@example.com"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_creates_user_from_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
tx = object()
|
||||
invited_user = _invited_user_db_record(
|
||||
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
|
||||
)
|
||||
created_user = _user_db_record()
|
||||
|
||||
outside_user_repo = Mock()
|
||||
outside_user_repo.find_unique = AsyncMock(side_effect=[None, created_user])
|
||||
|
||||
inside_user_repo = Mock()
|
||||
inside_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
inside_user_repo.create = AsyncMock(return_value=created_user)
|
||||
|
||||
outside_invited_repo = Mock()
|
||||
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
|
||||
inside_invited_repo = Mock()
|
||||
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
inside_invited_repo.update = AsyncMock(return_value=invited_user)
|
||||
|
||||
def user_prisma(client=None):
|
||||
return inside_user_repo if client is tx else outside_user_repo
|
||||
|
||||
def invited_user_prisma(client=None):
|
||||
return inside_invited_repo if client is tx else outside_invited_repo
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_transaction():
|
||||
yield tx
|
||||
|
||||
ensure_profile = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_profile",
|
||||
AsyncMock(),
|
||||
)
|
||||
ensure_onboarding = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_onboarding",
|
||||
AsyncMock(),
|
||||
)
|
||||
apply_tally = mocker.patch(
|
||||
"backend.data.invited_user._apply_tally_understanding",
|
||||
AsyncMock(),
|
||||
)
|
||||
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
side_effect=invited_user_prisma,
|
||||
)
|
||||
mocker.patch("backend.data.user.get_user_by_id.cache_delete", Mock())
|
||||
mocker.patch("backend.data.user.get_user_by_email.cache_delete", Mock())
|
||||
|
||||
user = await get_or_activate_user(
|
||||
{
|
||||
"sub": "auth-user-1",
|
||||
"email": "Invited@example.com",
|
||||
"user_metadata": {"name": "Invited User"},
|
||||
}
|
||||
)
|
||||
|
||||
assert user.id == "auth-user-1"
|
||||
inside_user_repo.create.assert_awaited_once()
|
||||
inside_invited_repo.update.assert_awaited_once()
|
||||
ensure_profile.assert_awaited_once()
|
||||
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
|
||||
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_from_text_file(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_db_record(),
|
||||
_invited_user_db_record(),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.txt",
|
||||
b"Invited@example.com\nsecond@example.com\n",
|
||||
)
|
||||
|
||||
assert result.created_count == 2
|
||||
assert result.skipped_count == 0
|
||||
assert result.error_count == 0
|
||||
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_db_record(),
|
||||
PreconditionFailed("An invited user with this email already exists"),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.csv",
|
||||
(
|
||||
"email,name\n"
|
||||
"valid@example.com,Valid User\n"
|
||||
"not-an-email,Bad Row\n"
|
||||
"valid@example.com,Duplicate In File\n"
|
||||
"existing@example.com,Existing User\n"
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
assert result.created_count == 1
|
||||
assert result.skipped_count == 2
|
||||
assert result.error_count == 1
|
||||
assert [row.status for row in result.results] == [
|
||||
"CREATED",
|
||||
"ERROR",
|
||||
"SKIPPED",
|
||||
"SKIPPED",
|
||||
]
|
||||
assert create_invited.await_count == 2
|
||||
@@ -13,7 +13,14 @@ from prisma.types import (
|
||||
)
|
||||
|
||||
# from backend.notifications.models import NotificationEvent
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
EmailStr,
|
||||
Field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
@@ -175,10 +182,26 @@ class RefundRequestData(BaseNotificationData):
|
||||
balance: int
|
||||
|
||||
|
||||
class AgentApprovalData(BaseNotificationData):
|
||||
class _LegacyAgentFieldsMixin:
|
||||
"""Temporary patch to handle existing queued payloads"""
|
||||
|
||||
# FIXME: remove in next release
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _map_legacy_agent_fields(cls, values: Any):
|
||||
if isinstance(values, dict):
|
||||
if "graph_id" not in values and "agent_id" in values:
|
||||
values["graph_id"] = values.pop("agent_id")
|
||||
if "graph_version" not in values and "agent_version" in values:
|
||||
values["graph_version"] = values.pop("agent_version")
|
||||
return values
|
||||
|
||||
|
||||
class AgentApprovalData(_LegacyAgentFieldsMixin, BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
@@ -193,10 +216,10 @@ class AgentApprovalData(BaseNotificationData):
|
||||
return value
|
||||
|
||||
|
||||
class AgentRejectionData(BaseNotificationData):
|
||||
class AgentRejectionData(_LegacyAgentFieldsMixin, BaseNotificationData):
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
reviewer_name: str
|
||||
reviewer_email: str
|
||||
comments: str
|
||||
|
||||
@@ -15,8 +15,8 @@ class TestAgentApprovalData:
|
||||
"""Test creating valid AgentApprovalData."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
@@ -25,8 +25,8 @@ class TestAgentApprovalData:
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.graph_id == "test-agent-123"
|
||||
assert data.graph_version == 1
|
||||
assert data.reviewer_name == "John Doe"
|
||||
assert data.reviewer_email == "john@example.com"
|
||||
assert data.comments == "Great agent, approved!"
|
||||
@@ -40,8 +40,8 @@ class TestAgentApprovalData:
|
||||
):
|
||||
AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="Great agent, approved!",
|
||||
@@ -53,8 +53,8 @@ class TestAgentApprovalData:
|
||||
"""Test AgentApprovalData with empty comments."""
|
||||
data = AgentApprovalData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="John Doe",
|
||||
reviewer_email="john@example.com",
|
||||
comments="", # Empty comments
|
||||
@@ -72,8 +72,8 @@ class TestAgentRejectionData:
|
||||
"""Test creating valid AgentRejectionData."""
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues before resubmitting.",
|
||||
@@ -82,8 +82,8 @@ class TestAgentRejectionData:
|
||||
)
|
||||
|
||||
assert data.agent_name == "Test Agent"
|
||||
assert data.agent_id == "test-agent-123"
|
||||
assert data.agent_version == 1
|
||||
assert data.graph_id == "test-agent-123"
|
||||
assert data.graph_version == 1
|
||||
assert data.reviewer_name == "Jane Doe"
|
||||
assert data.reviewer_email == "jane@example.com"
|
||||
assert data.comments == "Please fix the security issues before resubmitting."
|
||||
@@ -97,8 +97,8 @@ class TestAgentRejectionData:
|
||||
):
|
||||
AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the security issues.",
|
||||
@@ -111,8 +111,8 @@ class TestAgentRejectionData:
|
||||
long_comment = "A" * 1000 # Very long comment
|
||||
data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments=long_comment,
|
||||
@@ -126,8 +126,8 @@ class TestAgentRejectionData:
|
||||
"""Test that models can be serialized and deserialized."""
|
||||
original_data = AgentRejectionData(
|
||||
agent_name="Test Agent",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
graph_id="test-agent-123",
|
||||
graph_version=1,
|
||||
reviewer_name="Jane Doe",
|
||||
reviewer_email="jane@example.com",
|
||||
comments="Please fix the issues.",
|
||||
@@ -142,8 +142,8 @@ class TestAgentRejectionData:
|
||||
restored_data = AgentRejectionData.model_validate(data_dict)
|
||||
|
||||
assert restored_data.agent_name == original_data.agent_name
|
||||
assert restored_data.agent_id == original_data.agent_id
|
||||
assert restored_data.agent_version == original_data.agent_version
|
||||
assert restored_data.graph_id == original_data.graph_id
|
||||
assert restored_data.graph_version == original_data.graph_version
|
||||
assert restored_data.reviewer_name == original_data.reviewer_name
|
||||
assert restored_data.reviewer_email == original_data.reviewer_email
|
||||
assert restored_data.comments == original_data.comments
|
||||
|
||||
@@ -244,7 +244,10 @@ def _clean_and_split(text: str) -> list[str]:
|
||||
|
||||
|
||||
def _calculate_points(
|
||||
agent, categories: list[str], custom: list[str], integrations: list[str]
|
||||
agent: prisma.models.StoreAgent,
|
||||
categories: list[str],
|
||||
custom: list[str],
|
||||
integrations: list[str],
|
||||
) -> int:
|
||||
"""
|
||||
Calculates the total points for an agent based on the specified criteria.
|
||||
@@ -397,7 +400,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where={
|
||||
"is_available": True,
|
||||
"useForOnboarding": True,
|
||||
"use_for_onboarding": True,
|
||||
},
|
||||
order=[
|
||||
{"featured": "desc"},
|
||||
@@ -407,7 +410,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
take=100,
|
||||
)
|
||||
|
||||
# If not enough agents found, relax the useForOnboarding filter
|
||||
# If not enough agents found, relax the use_for_onboarding filter
|
||||
if len(storeAgents) < 2:
|
||||
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=prisma.types.StoreAgentWhereInput(**where_clause),
|
||||
@@ -420,7 +423,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
)
|
||||
|
||||
# Calculate points for the first X agents and choose the top 2
|
||||
agent_points = []
|
||||
agent_points: list[tuple[prisma.models.StoreAgent, int]] = []
|
||||
for agent in storeAgents[:POINTS_AGENT_COUNT]:
|
||||
points = _calculate_points(
|
||||
agent, categories, custom, user_onboarding.integrations
|
||||
@@ -430,28 +433,7 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
agent_points.sort(key=lambda x: x[1], reverse=True)
|
||||
recommended_agents = [agent for agent, _ in agent_points[:2]]
|
||||
|
||||
return [
|
||||
StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
agentGraphVersions=agent.agentGraphVersions,
|
||||
agentGraphId=agent.agentGraphId,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
for agent in recommended_agents
|
||||
]
|
||||
return [StoreAgentDetails.from_db(agent) for agent in recommended_agents]
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes
|
||||
|
||||
@@ -380,6 +380,35 @@ async def extract_business_understanding(
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def get_business_understanding_input_from_tally(
|
||||
email: str,
|
||||
*,
|
||||
require_api_key: bool = False,
|
||||
) -> Optional[BusinessUnderstandingInput]:
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
if require_api_key:
|
||||
raise RuntimeError("Tally API key is not configured")
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return None
|
||||
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return None
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return None
|
||||
|
||||
return await extract_business_understanding(formatted)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
@@ -394,30 +423,10 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
# Check API key is configured
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
understanding_input = await get_business_understanding_input_from_tally(email)
|
||||
if understanding_input is None:
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
@@ -166,6 +166,56 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
return merged
|
||||
|
||||
|
||||
def merge_business_understanding_data(
|
||||
existing_data: dict[str, Any],
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> dict[str, Any]:
|
||||
merged_data = dict(existing_data)
|
||||
|
||||
merged_business: dict[str, Any] = {}
|
||||
if isinstance(merged_data.get("business"), dict):
|
||||
merged_business = dict(merged_data["business"])
|
||||
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
if input_data.user_name is not None:
|
||||
merged_data["name"] = input_data.user_name
|
||||
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
merged_business[field] = value
|
||||
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(merged_business.get(field))
|
||||
merged_list = _merge_lists(existing_list, value)
|
||||
merged_business[field] = merged_list
|
||||
|
||||
merged_business["version"] = 1
|
||||
merged_data["business"] = merged_business
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
@@ -245,63 +295,18 @@ async def upsert_business_understanding(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
merged_data = merge_business_understanding_data(existing_data, input_data)
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
"create": {"userId": user_id, "data": SafeJson(merged_data)},
|
||||
"update": {"data": SafeJson(merged_data)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi.responses
|
||||
import prisma
|
||||
import pytest
|
||||
|
||||
import backend.api.features.library.model
|
||||
@@ -497,9 +498,24 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
# Ensure the test user has a Profile (required for store submissions)
|
||||
existing_profile = await prisma.models.Profile.prisma().find_first(
|
||||
where={"userId": test_user.id}
|
||||
)
|
||||
if not existing_profile:
|
||||
await prisma.models.Profile.prisma().create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=test_user.id,
|
||||
name=test_user.name or "Test User",
|
||||
username=f"test-user-{test_user.id[:8]}",
|
||||
description="Test user profile",
|
||||
links=[],
|
||||
)
|
||||
)
|
||||
|
||||
store_submission_request = backend.api.features.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
slug=test_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
@@ -517,8 +533,8 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
store_listing.listing_version_id
|
||||
if store_listing.listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.data import graph as graph_db
|
||||
from backend.data import human_review as human_review_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data import workspace as workspace_db
|
||||
|
||||
# Import dynamic field utilities from centralized location
|
||||
from backend.data.block import BlockInput, BlockOutputEntry
|
||||
@@ -32,7 +33,6 @@ from backend.data.execution import (
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput, GraphInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -481,6 +481,22 @@ async def _construct_starting_node_execution_input(
|
||||
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
|
||||
input_data.update(node_input_mask)
|
||||
|
||||
# Webhook-triggered agents cannot be executed directly without payload data.
|
||||
# Legitimate webhook triggers provide payload via nodes_input_masks above.
|
||||
if (
|
||||
block.block_type
|
||||
in (
|
||||
BlockType.WEBHOOK,
|
||||
BlockType.WEBHOOK_MANUAL,
|
||||
)
|
||||
and "payload" not in input_data
|
||||
):
|
||||
raise ValueError(
|
||||
"This agent is triggered by an external event (webhook) "
|
||||
"and cannot be executed directly. "
|
||||
"Please use the appropriate trigger to run this agent."
|
||||
)
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise ValueError(error)
|
||||
@@ -831,8 +847,9 @@ async def add_graph_execution(
|
||||
udb = user_db
|
||||
gdb = graph_db
|
||||
odb = onboarding_db
|
||||
wdb = workspace_db
|
||||
else:
|
||||
edb = udb = gdb = odb = get_database_manager_async_client()
|
||||
edb = udb = gdb = odb = wdb = get_database_manager_async_client()
|
||||
|
||||
# Get or create the graph execution
|
||||
if graph_exec_id:
|
||||
@@ -892,7 +909,7 @@ async def add_graph_execution(
|
||||
if execution_context is None:
|
||||
user = await udb.get_user_by_id(user_id)
|
||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await wdb.get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
|
||||
@@ -368,12 +368,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_wdb = mocker.patch("backend.executor.utils.workspace_db")
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
mock_wdb.get_or_create_workspace = mocker.AsyncMock(return_value=mock_workspace)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
@@ -649,12 +647,10 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_wdb = mocker.patch("backend.executor.utils.workspace_db")
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
mock_wdb.get_or_create_workspace = mocker.AsyncMock(return_value=mock_workspace)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
|
||||
@@ -76,7 +76,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
|
||||
credentials_id=credentials.id,
|
||||
webhook_type=webhook_type,
|
||||
resource=resource,
|
||||
events=None, # Ignore events for this lookup
|
||||
):
|
||||
# Re-register with Telegram using the same URL but new allowed_updates
|
||||
ingress_url = webhook_ingress_url(self.PROVIDER_NAME, existing.id)
|
||||
@@ -143,10 +142,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
|
||||
elif "video" in message:
|
||||
event_type = "message.video"
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown Telegram webhook payload type; "
|
||||
f"message.keys() = {message.keys()}"
|
||||
)
|
||||
event_type = "message.other"
|
||||
elif "edited_message" in payload:
|
||||
event_type = "message.edited_message"
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the approved agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.graph_id: the ID of the agent
|
||||
data.graph_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who approved it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer
|
||||
@@ -70,4 +70,4 @@
|
||||
Thank you for contributing to the AutoGPT ecosystem! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
{#
|
||||
Template variables:
|
||||
data.agent_name: the name of the rejected agent
|
||||
data.agent_id: the ID of the agent
|
||||
data.agent_version: the version of the agent
|
||||
data.graph_id: the ID of the agent
|
||||
data.graph_version: the version of the agent
|
||||
data.reviewer_name: the name of the reviewer who rejected it
|
||||
data.reviewer_email: the email of the reviewer
|
||||
data.comments: comments from the reviewer explaining the rejection
|
||||
@@ -74,4 +74,4 @@
|
||||
We're excited to see your improved agent submission! 🚀
|
||||
</p>
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
||||
@@ -64,6 +64,10 @@ class GraphNotInLibraryError(GraphNotAccessibleError):
|
||||
"""Raised when attempting to execute a graph that is not / no longer in the user's library."""
|
||||
|
||||
|
||||
class PreconditionFailed(Exception):
|
||||
"""The user must do something else first before trying the current operation"""
|
||||
|
||||
|
||||
class InsufficientBalanceError(ValueError):
|
||||
user_id: str
|
||||
message: str
|
||||
|
||||
@@ -72,19 +72,58 @@ def dumps(
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
|
||||
# Sentinel value to detect when fallback is not provided
|
||||
_NO_FALLBACK = object()
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T], fallback: T | None = None, **kwargs
|
||||
) -> T:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str | bytes, *args, fallback: Any = None, **kwargs) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
|
||||
data: str | bytes,
|
||||
*args,
|
||||
target_type: Type[T] | None = None,
|
||||
fallback: Any = _NO_FALLBACK,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
parsed = orjson.loads(data)
|
||||
"""Parse JSON with optional fallback on decode errors.
|
||||
|
||||
Args:
|
||||
data: JSON string or bytes to parse
|
||||
target_type: Optional type to validate/cast result to
|
||||
fallback: Value to return on JSONDecodeError. If not provided, raises.
|
||||
**kwargs: Additional arguments (unused, for compatibility)
|
||||
|
||||
Returns:
|
||||
Parsed JSON data, or fallback value if parsing fails
|
||||
|
||||
Raises:
|
||||
orjson.JSONDecodeError: Only if fallback is not provided
|
||||
|
||||
Examples:
|
||||
>>> loads('{"valid": "json"}')
|
||||
{'valid': 'json'}
|
||||
>>> loads('invalid json', fallback=None)
|
||||
None
|
||||
>>> loads('invalid json', fallback={})
|
||||
{}
|
||||
>>> loads('invalid json') # raises orjson.JSONDecodeError
|
||||
"""
|
||||
try:
|
||||
parsed = orjson.loads(data)
|
||||
except orjson.JSONDecodeError:
|
||||
if fallback is not _NO_FALLBACK:
|
||||
return fallback
|
||||
raise
|
||||
|
||||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
BEGIN;
|
||||
|
||||
-- Drop illogical column StoreListing.agentGraphVersion;
|
||||
ALTER TABLE "StoreListing" DROP CONSTRAINT "StoreListing_agentGraphId_agentGraphVersion_fkey";
|
||||
DROP INDEX "StoreListing_agentGraphId_agentGraphVersion_idx";
|
||||
ALTER TABLE "StoreListing" DROP COLUMN "agentGraphVersion";
|
||||
|
||||
-- Add uniqueness constraint to Profile.userId and remove invalid data
|
||||
--
|
||||
-- Delete any profiles with null userId (which is invalid and doesn't occur in theory)
|
||||
DELETE FROM "Profile" WHERE "userId" IS NULL;
|
||||
--
|
||||
-- Delete duplicate profiles per userId, keeping the most recently updated one
|
||||
DELETE FROM "Profile"
|
||||
WHERE "id" IN (
|
||||
SELECT "id" FROM (
|
||||
SELECT "id", ROW_NUMBER() OVER (
|
||||
PARTITION BY "userId" ORDER BY "updatedAt" DESC, "id" DESC
|
||||
) AS rn
|
||||
FROM "Profile"
|
||||
) ranked
|
||||
WHERE rn > 1
|
||||
);
|
||||
--
|
||||
-- Add userId uniqueness constraint
|
||||
ALTER TABLE "Profile" ALTER COLUMN "userId" SET NOT NULL;
|
||||
CREATE UNIQUE INDEX "Profile_userId_key" ON "Profile"("userId");
|
||||
|
||||
-- Add formal relation StoreListing.owningUserId -> Profile.userId
|
||||
ALTER TABLE "StoreListing" ADD CONSTRAINT "StoreListing_owner_Profile_fkey" FOREIGN KEY ("owningUserId") REFERENCES "Profile"("userId") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
COMMIT;
|
||||
@@ -0,0 +1,219 @@
|
||||
-- Update the StoreSubmission and StoreAgent views with additional fields, clearer field names, and faster joins.
|
||||
-- Steps:
|
||||
-- 1. Update `mv_agent_run_counts` to exclude runs by the agent's creator
|
||||
-- a. Drop dependent views `StoreAgent` and `Creator`
|
||||
-- b. Update `mv_agent_run_counts` and its index
|
||||
-- c. Recreate `StoreAgent` view (with updates)
|
||||
-- d. Restore `Creator` view
|
||||
-- 2. Update `StoreSubmission` view
|
||||
-- 3. Update `StoreListingReview` indices to make `StoreSubmission` query more efficient
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- Drop views that are dependent on mv_agent_run_counts
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
DROP VIEW IF EXISTS "Creator";
|
||||
|
||||
-- Update materialized view for agent run counts to exclude runs by the agent's creator
|
||||
DROP INDEX IF EXISTS "idx_mv_agent_run_counts";
|
||||
DROP MATERIALIZED VIEW IF EXISTS "mv_agent_run_counts";
|
||||
CREATE MATERIALIZED VIEW "mv_agent_run_counts" AS
|
||||
SELECT
|
||||
run."agentGraphId" AS graph_id,
|
||||
COUNT(*) AS run_count
|
||||
FROM "AgentGraphExecution" run
|
||||
JOIN "AgentGraph" graph ON graph.id = run."agentGraphId"
|
||||
-- Exclude runs by the agent's creator to avoid inflating run counts
|
||||
WHERE graph."userId" != run."userId"
|
||||
GROUP BY run."agentGraphId";
|
||||
|
||||
-- Recreate index
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_agent_run_counts" ON "mv_agent_run_counts"("graph_id");
|
||||
|
||||
-- Re-populate the materialized view
|
||||
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
|
||||
|
||||
|
||||
-- Recreate the StoreAgent view with the following changes
|
||||
-- (compared to 20260115210000_remove_storelistingversion_search):
|
||||
-- - Narrow to *explicitly active* version (sl.activeVersionId) instead of MAX(version)
|
||||
-- - Add `recommended_schedule_cron` column
|
||||
-- - Rename `"storeListingVersionId"` -> `listing_version_id`
|
||||
-- - Rename `"agentGraphVersions"` -> `graph_versions`
|
||||
-- - Rename `"agentGraphId"` -> `graph_id`
|
||||
-- - Rename `"useForOnboarding"` -> `use_for_onboarding`
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH store_agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
),
|
||||
agent_graph_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT "agentGraphVersion"::text ORDER BY "agentGraphVersion"::text) AS graph_versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS listing_version_id,
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
slv."agentOutputDemoUrl" AS agent_output_demo,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
cp.username AS creator_username,
|
||||
cp."avatarUrl" AS creator_avatar,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
COALESCE(arc.run_count, 0::bigint) AS runs,
|
||||
COALESCE(reviews.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(sav.versions, ARRAY[slv.version::text]) AS versions,
|
||||
slv."agentGraphId" AS graph_id,
|
||||
COALESCE(
|
||||
agv.graph_versions,
|
||||
ARRAY[slv."agentGraphVersion"::text]
|
||||
) AS graph_versions,
|
||||
slv."isAvailable" AS is_available,
|
||||
COALESCE(sl."useForOnboarding", false) AS use_for_onboarding,
|
||||
slv."recommendedScheduleCron" AS recommended_schedule_cron
|
||||
FROM "StoreListing" AS sl
|
||||
JOIN "StoreListingVersion" AS slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv.id = sl."activeVersionId"
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" AS ag
|
||||
ON slv."agentGraphId" = ag.id
|
||||
AND slv."agentGraphVersion" = ag.version
|
||||
LEFT JOIN "Profile" AS cp
|
||||
ON sl."owningUserId" = cp."userId"
|
||||
LEFT JOIN "mv_review_stats" AS reviews
|
||||
ON sl.id = reviews."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" AS arc
|
||||
ON ag.id = arc.graph_id
|
||||
LEFT JOIN store_agent_versions AS sav
|
||||
ON sl.id = sav."storeListingId"
|
||||
LEFT JOIN agent_graph_versions AS agv
|
||||
ON sl.id = agv."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
|
||||
|
||||
-- Restore Creator view as last updated in 20250604130249_optimise_store_agent_and_creator_views,
|
||||
-- with minor changes:
|
||||
-- - Ensure top_categories always TEXT[]
|
||||
-- - Filter out empty ('') categories
|
||||
CREATE OR REPLACE VIEW "Creator" AS
|
||||
WITH creator_listings AS (
|
||||
SELECT
|
||||
sl."owningUserId",
|
||||
sl.id AS listing_id,
|
||||
slv."agentGraphId",
|
||||
slv.categories,
|
||||
sr.score,
|
||||
ar.run_count
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = sl.id
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
LEFT JOIN "StoreListingReview" sr
|
||||
ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON ar.graph_id = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true
|
||||
),
|
||||
creator_stats AS (
|
||||
SELECT
|
||||
cl."owningUserId",
|
||||
COUNT(DISTINCT cl.listing_id) AS num_agents,
|
||||
AVG(COALESCE(cl.score, 0)::numeric) AS agent_rating,
|
||||
SUM(COALESCE(cl.run_count, 0)) AS agent_runs,
|
||||
array_agg(DISTINCT cat ORDER BY cat)
|
||||
FILTER (WHERE cat IS NOT NULL AND cat != '') AS all_categories
|
||||
FROM creator_listings cl
|
||||
LEFT JOIN LATERAL unnest(COALESCE(cl.categories, ARRAY[]::text[])) AS cat ON true
|
||||
GROUP BY cl."owningUserId"
|
||||
)
|
||||
SELECT
|
||||
p.username,
|
||||
p.name,
|
||||
p."avatarUrl" AS avatar_url,
|
||||
p.description,
|
||||
COALESCE(cs.all_categories, ARRAY[]::text[]) AS top_categories,
|
||||
p.links,
|
||||
p."isFeatured" AS is_featured,
|
||||
COALESCE(cs.num_agents, 0::bigint) AS num_agents,
|
||||
COALESCE(cs.agent_rating, 0.0) AS agent_rating,
|
||||
COALESCE(cs.agent_runs, 0::numeric) AS agent_runs
|
||||
FROM "Profile" p
|
||||
LEFT JOIN creator_stats cs ON cs."owningUserId" = p."userId";
|
||||
|
||||
|
||||
-- Recreate the StoreSubmission view with updated fields & query strategy:
|
||||
-- - Uses mv_agent_run_counts instead of full AgentGraphExecution table scan + aggregation
|
||||
-- - Renamed agent_id, agent_version -> graph_id, graph_version
|
||||
-- - Renamed store_listing_version_id -> listing_version_id
|
||||
-- - Renamed date_submitted -> submitted_at
|
||||
-- - Renamed runs, rating -> run_count, review_avg_rating
|
||||
-- - Added fields: instructions, agent_output_demo_url, review_count, is_deleted
|
||||
DROP VIEW IF EXISTS "StoreSubmission";
|
||||
CREATE OR REPLACE VIEW "StoreSubmission" AS
|
||||
WITH review_stats AS (
|
||||
SELECT
|
||||
"storeListingVersionId" AS version_id, -- more specific than mv_review_stats
|
||||
avg(score) AS avg_rating,
|
||||
count(*) AS review_count
|
||||
FROM "StoreListingReview"
|
||||
GROUP BY "storeListingVersionId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
sl."owningUserId" AS user_id,
|
||||
sl.slug AS slug,
|
||||
|
||||
slv.id AS listing_version_id,
|
||||
slv.version AS listing_version,
|
||||
slv."agentGraphId" AS graph_id,
|
||||
slv."agentGraphVersion" AS graph_version,
|
||||
slv.name AS name,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description AS description,
|
||||
slv.instructions AS instructions,
|
||||
slv.categories AS categories,
|
||||
slv."imageUrls" AS image_urls,
|
||||
slv."videoUrl" AS video_url,
|
||||
slv."agentOutputDemoUrl" AS agent_output_demo_url,
|
||||
slv."submittedAt" AS submitted_at,
|
||||
slv."changesSummary" AS changes_summary,
|
||||
slv."submissionStatus" AS status,
|
||||
slv."reviewedAt" AS reviewed_at,
|
||||
slv."reviewerId" AS reviewer_id,
|
||||
slv."reviewComments" AS review_comments,
|
||||
slv."internalComments" AS internal_comments,
|
||||
slv."isDeleted" AS is_deleted,
|
||||
|
||||
COALESCE(run_stats.run_count, 0::bigint) AS run_count,
|
||||
COALESCE(review_stats.review_count, 0::bigint) AS review_count,
|
||||
COALESCE(review_stats.avg_rating, 0.0)::double precision AS review_avg_rating
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN review_stats ON review_stats.version_id = slv.id
|
||||
LEFT JOIN mv_agent_run_counts run_stats ON run_stats.graph_id = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false;
|
||||
|
||||
|
||||
-- Drop unused index on StoreListingReview.reviewByUserId
|
||||
DROP INDEX IF EXISTS "StoreListingReview_reviewByUserId_idx";
|
||||
-- Add index on storeListingVersionId to make StoreSubmission query faster
|
||||
CREATE INDEX "StoreListingReview_storeListingVersionId_idx" ON "StoreListingReview"("storeListingVersionId");
|
||||
|
||||
COMMIT;
|
||||
@@ -0,0 +1,114 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "InvitedUserStatus" AS ENUM ('INVITED', 'CLAIMED', 'REVOKED');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "TallyComputationStatus" AS ENUM ('PENDING', 'RUNNING', 'READY', 'FAILED');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "InvitedUser" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
|
||||
"authUserId" TEXT,
|
||||
"name" TEXT,
|
||||
"tallyUnderstanding" JSONB,
|
||||
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"tallyComputedAt" TIMESTAMP(3),
|
||||
"tallyError" TEXT,
|
||||
|
||||
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_email_key" ON "InvitedUser"("email");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_authUserId_key" ON "InvitedUser"("authUserId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_status_idx" ON "InvitedUser"("status");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_tallyStatus_idx" ON "InvitedUser"("tallyStatus");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey"
|
||||
FOREIGN KEY ("authUserId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
DO $$
|
||||
DECLARE
|
||||
allowed_users_schema TEXT;
|
||||
BEGIN
|
||||
SELECT table_schema
|
||||
INTO allowed_users_schema
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'allowed_users'
|
||||
AND column_name = 'email'
|
||||
ORDER BY CASE
|
||||
WHEN table_schema = 'platform' THEN 0
|
||||
WHEN table_schema = 'public' THEN 1
|
||||
ELSE 2
|
||||
END
|
||||
LIMIT 1;
|
||||
|
||||
IF allowed_users_schema IS NOT NULL THEN
|
||||
EXECUTE format(
|
||||
'INSERT INTO platform."InvitedUser" ("id", "email", "status", "tallyStatus", "createdAt", "updatedAt")
|
||||
SELECT gen_random_uuid()::text,
|
||||
lower(email),
|
||||
''INVITED''::platform."InvitedUserStatus",
|
||||
''PENDING''::platform."TallyComputationStatus",
|
||||
now(),
|
||||
now()
|
||||
FROM %I.allowed_users
|
||||
WHERE email IS NOT NULL
|
||||
ON CONFLICT ("email") DO NOTHING',
|
||||
allowed_users_schema
|
||||
);
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION platform.ensure_invited_user_can_register()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
IF NEW.email IS NULL THEN
|
||||
RAISE EXCEPTION 'The email address "%" is not allowed to register. Please contact support for assistance.', NEW.email
|
||||
USING ERRCODE = 'P0001';
|
||||
END IF;
|
||||
|
||||
IF lower(split_part(NEW.email, '@', 2)) = 'agpt.co' THEN
|
||||
RETURN NEW;
|
||||
END IF;
|
||||
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM platform."InvitedUser" invited_user
|
||||
WHERE lower(invited_user.email) = lower(NEW.email)
|
||||
AND invited_user.status = 'INVITED'::platform."InvitedUserStatus"
|
||||
) THEN
|
||||
RETURN NEW;
|
||||
END IF;
|
||||
|
||||
RAISE EXCEPTION 'The email address "%" is not allowed to register. Please contact support for assistance.', NEW.email
|
||||
USING ERRCODE = 'P0001';
|
||||
END;
|
||||
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'auth'
|
||||
AND table_name = 'users'
|
||||
) THEN
|
||||
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
|
||||
DROP TRIGGER IF EXISTS invited_user_signup_gate ON auth.users;
|
||||
|
||||
CREATE TRIGGER invited_user_signup_gate
|
||||
BEFORE INSERT ON auth.users
|
||||
FOR EACH ROW EXECUTE FUNCTION platform.ensure_invited_user_can_register();
|
||||
END IF;
|
||||
END $$;
|
||||
@@ -65,6 +65,7 @@ model User {
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -73,6 +74,39 @@ model User {
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
}
|
||||
|
||||
enum InvitedUserStatus {
|
||||
INVITED
|
||||
CLAIMED
|
||||
REVOKED
|
||||
}
|
||||
|
||||
enum TallyComputationStatus {
|
||||
PENDING
|
||||
RUNNING
|
||||
READY
|
||||
FAILED
|
||||
}
|
||||
|
||||
model InvitedUser {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
email String @unique
|
||||
status InvitedUserStatus @default(INVITED)
|
||||
authUserId String? @unique
|
||||
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
|
||||
name String?
|
||||
|
||||
tallyUnderstanding Json?
|
||||
tallyStatus TallyComputationStatus @default(PENDING)
|
||||
tallyComputedAt DateTime?
|
||||
tallyError String?
|
||||
|
||||
@@index([status])
|
||||
@@index([tallyStatus])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
// Introductory onboarding (Library)
|
||||
WELCOME
|
||||
@@ -281,7 +315,6 @@ model AgentGraph {
|
||||
|
||||
Presets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
StoreListings StoreListing[]
|
||||
StoreListingVersions StoreListingVersion[]
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
@@ -814,10 +847,8 @@ model Profile {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
// Only 1 of user or group can be set.
|
||||
// The user this profile belongs to, if any.
|
||||
userId String?
|
||||
User User? @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
name String
|
||||
username String @unique
|
||||
@@ -830,6 +861,7 @@ model Profile {
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
LibraryAgents LibraryAgent[]
|
||||
StoreListings StoreListing[]
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
@@ -860,9 +892,9 @@ view Creator {
|
||||
}
|
||||
|
||||
view StoreAgent {
|
||||
listing_id String @id
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
listing_id String @id
|
||||
listing_version_id String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
@@ -879,10 +911,12 @@ view StoreAgent {
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
graph_id String
|
||||
graph_versions String[]
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
use_for_onboarding Boolean @default(false)
|
||||
|
||||
recommended_schedule_cron String?
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -896,41 +930,52 @@ view StoreAgent {
|
||||
}
|
||||
|
||||
view StoreSubmission {
|
||||
listing_id String @id
|
||||
user_id String
|
||||
slug String
|
||||
name String
|
||||
sub_heading String
|
||||
description String
|
||||
image_urls String[]
|
||||
date_submitted DateTime
|
||||
status SubmissionStatus
|
||||
runs Int
|
||||
rating Float
|
||||
agent_id String
|
||||
agent_version Int
|
||||
store_listing_version_id String
|
||||
reviewer_id String?
|
||||
review_comments String?
|
||||
internal_comments String?
|
||||
reviewed_at DateTime?
|
||||
changes_summary String?
|
||||
video_url String?
|
||||
categories String[]
|
||||
// From StoreListing:
|
||||
listing_id String
|
||||
user_id String
|
||||
slug String
|
||||
|
||||
// Index or unique are not applied to views
|
||||
// From StoreListingVersion:
|
||||
listing_version_id String @id
|
||||
listing_version Int
|
||||
graph_id String
|
||||
graph_version Int
|
||||
|
||||
name String
|
||||
sub_heading String
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
image_urls String[]
|
||||
video_url String?
|
||||
agent_output_demo_url String?
|
||||
|
||||
submitted_at DateTime?
|
||||
changes_summary String?
|
||||
status SubmissionStatus
|
||||
reviewed_at DateTime?
|
||||
reviewer_id String?
|
||||
review_comments String?
|
||||
internal_comments String?
|
||||
|
||||
is_deleted Boolean
|
||||
|
||||
// Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count Int
|
||||
review_count Int
|
||||
review_avg_rating Float
|
||||
}
|
||||
|
||||
// Note: This is actually a MATERIALIZED VIEW in the database
|
||||
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
|
||||
view mv_agent_run_counts {
|
||||
agentGraphId String @unique
|
||||
run_count Int
|
||||
graph_id String @unique
|
||||
run_count Int // excluding runs by the graph's creator
|
||||
|
||||
// Pre-aggregated count of AgentGraphExecution records by agentGraphId
|
||||
// Used by StoreAgent and Creator views for performance optimization
|
||||
// Unique index created automatically on agentGraphId for fast lookups
|
||||
// Refresh uses CONCURRENTLY to avoid blocking reads
|
||||
// Pre-aggregated count of AgentGraphExecution records by agentGraphId.
|
||||
// Used by StoreAgent, Creator, and StoreSubmission views for performance optimization.
|
||||
// - Should have a unique index on graph_id for fast lookups
|
||||
// - Refresh should use CONCURRENTLY to avoid blocking reads
|
||||
}
|
||||
|
||||
// Note: This is actually a MATERIALIZED VIEW in the database
|
||||
@@ -979,22 +1024,18 @@ model StoreListing {
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentGraphId String
|
||||
agentGraphVersion Int
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version], onDelete: Cascade)
|
||||
agentGraphId String @unique
|
||||
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
CreatorProfile Profile @relation(fields: [owningUserId], references: [userId], map: "StoreListing_owner_Profile_fkey", onDelete: Cascade)
|
||||
|
||||
// Relations
|
||||
Versions StoreListingVersion[] @relation("ListingVersions")
|
||||
|
||||
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
|
||||
@@unique([agentGraphId])
|
||||
@@unique([owningUserId, slug])
|
||||
// Used in the view query
|
||||
@@index([isDeleted, hasApprovedVersion])
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
}
|
||||
|
||||
model StoreListingVersion {
|
||||
@@ -1089,16 +1130,16 @@ model UnifiedContentEmbedding {
|
||||
// Search data
|
||||
embedding Unsupported("vector(1536)") // pgvector embedding (extension in platform schema)
|
||||
searchableText String // Combined text for search and fallback
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) // Full-text search (auto-populated by trigger)
|
||||
metadata Json @default("{}") // Content-specific metadata
|
||||
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
|
||||
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
|
||||
|
||||
@@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
|
||||
@@index([contentType])
|
||||
@@index([userId])
|
||||
@@index([contentType, userId])
|
||||
@@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
|
||||
// NO @@index for search - GIN index "UnifiedContentEmbedding_search_idx" created via SQL migration
|
||||
// Prisma may generate DROP INDEX on migrate dev - that's okay, migration recreates it
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
@@ -1115,8 +1156,9 @@ model StoreListingReview {
|
||||
score Int
|
||||
comments String?
|
||||
|
||||
// Enforce one review per user per listing version
|
||||
@@unique([storeListingVersionId, reviewByUserId])
|
||||
@@index([reviewByUserId])
|
||||
@@index([storeListingVersionId])
|
||||
}
|
||||
|
||||
enum SubmissionStatus {
|
||||
|
||||
@@ -23,14 +23,14 @@
|
||||
"1.0.0",
|
||||
"1.1.0"
|
||||
],
|
||||
"agentGraphVersions": [
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_versions": [
|
||||
"1",
|
||||
"2"
|
||||
],
|
||||
"agentGraphId": "test-graph-id",
|
||||
"last_updated": "2023-01-01T00:00:00",
|
||||
"recommended_schedule_cron": null,
|
||||
"active_version_id": null,
|
||||
"has_approved_version": false,
|
||||
"active_version_id": "test-version-id",
|
||||
"has_approved_version": true,
|
||||
"changelog": null
|
||||
}
|
||||
@@ -1,14 +1,16 @@
|
||||
{
|
||||
"name": "Test User",
|
||||
"username": "creator1",
|
||||
"name": "Test User",
|
||||
"description": "Test creator description",
|
||||
"avatar_url": "avatar.jpg",
|
||||
"links": [
|
||||
"link1.com",
|
||||
"link2.com"
|
||||
],
|
||||
"avatar_url": "avatar.jpg",
|
||||
"agent_rating": 4.8,
|
||||
"is_featured": true,
|
||||
"num_agents": 5,
|
||||
"agent_runs": 1000,
|
||||
"agent_rating": 4.8,
|
||||
"top_categories": [
|
||||
"category1",
|
||||
"category2"
|
||||
|
||||
@@ -1,54 +1,94 @@
|
||||
{
|
||||
"creators": [
|
||||
{
|
||||
"name": "Creator 0",
|
||||
"username": "creator0",
|
||||
"name": "Creator 0",
|
||||
"description": "Creator 0 description",
|
||||
"avatar_url": "avatar0.jpg",
|
||||
"links": [
|
||||
"user0.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 1",
|
||||
"username": "creator1",
|
||||
"name": "Creator 1",
|
||||
"description": "Creator 1 description",
|
||||
"avatar_url": "avatar1.jpg",
|
||||
"links": [
|
||||
"user1.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 2",
|
||||
"username": "creator2",
|
||||
"name": "Creator 2",
|
||||
"description": "Creator 2 description",
|
||||
"avatar_url": "avatar2.jpg",
|
||||
"links": [
|
||||
"user2.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 3",
|
||||
"username": "creator3",
|
||||
"name": "Creator 3",
|
||||
"description": "Creator 3 description",
|
||||
"avatar_url": "avatar3.jpg",
|
||||
"links": [
|
||||
"user3.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Creator 4",
|
||||
"username": "creator4",
|
||||
"name": "Creator 4",
|
||||
"description": "Creator 4 description",
|
||||
"avatar_url": "avatar4.jpg",
|
||||
"links": [
|
||||
"user4.link.com"
|
||||
],
|
||||
"is_featured": false,
|
||||
"num_agents": 1,
|
||||
"agent_rating": 4.5,
|
||||
"agent_runs": 100,
|
||||
"is_featured": false
|
||||
"agent_rating": 4.5,
|
||||
"top_categories": [
|
||||
"cat1",
|
||||
"cat2",
|
||||
"cat3"
|
||||
]
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -2,32 +2,33 @@
|
||||
"submissions": [
|
||||
{
|
||||
"listing_id": "test-listing-id",
|
||||
"agent_id": "test-agent-id",
|
||||
"agent_version": 1,
|
||||
"user_id": "test-user-id",
|
||||
"slug": "test-agent",
|
||||
"listing_version_id": "test-version-id",
|
||||
"listing_version": 1,
|
||||
"graph_id": "test-agent-id",
|
||||
"graph_version": 1,
|
||||
"name": "Test Agent",
|
||||
"sub_heading": "Test agent subheading",
|
||||
"slug": "test-agent",
|
||||
"description": "Test agent description",
|
||||
"instructions": null,
|
||||
"instructions": "Click the button!",
|
||||
"categories": [
|
||||
"test-category"
|
||||
],
|
||||
"image_urls": [
|
||||
"test.jpg"
|
||||
],
|
||||
"date_submitted": "2023-01-01T00:00:00",
|
||||
"video_url": "test.mp4",
|
||||
"agent_output_demo_url": "demo_video.mp4",
|
||||
"submitted_at": "2023-01-01T00:00:00",
|
||||
"changes_summary": "Initial Submission",
|
||||
"status": "APPROVED",
|
||||
"runs": 50,
|
||||
"rating": 4.2,
|
||||
"store_listing_version_id": null,
|
||||
"version": null,
|
||||
"reviewed_at": null,
|
||||
"reviewer_id": null,
|
||||
"review_comments": null,
|
||||
"internal_comments": null,
|
||||
"reviewed_at": null,
|
||||
"changes_summary": null,
|
||||
"video_url": "test.mp4",
|
||||
"agent_output_demo_url": null,
|
||||
"categories": [
|
||||
"test-category"
|
||||
]
|
||||
"run_count": 50,
|
||||
"review_count": 5,
|
||||
"review_avg_rating": 4.2
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -128,7 +128,7 @@ class TestDataCreator:
|
||||
email = "test123@gmail.com"
|
||||
else:
|
||||
email = faker.unique.email()
|
||||
password = "testpassword123" # Standard test password
|
||||
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
|
||||
user_id = f"test-user-{i}-{faker.uuid4()}"
|
||||
|
||||
# Create user in Supabase Auth (if needed)
|
||||
@@ -571,8 +571,8 @@ class TestDataCreator:
|
||||
if test_user and self.agent_graphs:
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"agent_id": self.agent_graphs[0]["id"],
|
||||
"agent_version": 1,
|
||||
"graph_id": self.agent_graphs[0]["id"],
|
||||
"graph_version": 1,
|
||||
"slug": "test-agent-submission",
|
||||
"name": "Test Agent Submission",
|
||||
"sub_heading": "A test agent for frontend testing",
|
||||
@@ -593,9 +593,9 @@ class TestDataCreator:
|
||||
print("✅ Created special test store submission for test123@gmail.com")
|
||||
|
||||
# ALWAYS approve and feature the test submission
|
||||
if test_submission.store_listing_version_id:
|
||||
if test_submission.listing_version_id:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.store_listing_version_id,
|
||||
store_listing_version_id=test_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Test submission approved",
|
||||
internal_comments="Auto-approved test submission",
|
||||
@@ -605,7 +605,7 @@ class TestDataCreator:
|
||||
print("✅ Approved test store submission")
|
||||
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.store_listing_version_id},
|
||||
where={"id": test_submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
@@ -640,8 +640,8 @@ class TestDataCreator:
|
||||
|
||||
submission = await create_store_submission(
|
||||
user_id=user["id"],
|
||||
agent_id=graph["id"],
|
||||
agent_version=graph.get("version", 1),
|
||||
graph_id=graph["id"],
|
||||
graph_version=graph.get("version", 1),
|
||||
slug=faker.slug(),
|
||||
name=graph.get("name", faker.sentence(nb_words=3)),
|
||||
sub_heading=faker.sentence(),
|
||||
@@ -654,7 +654,7 @@ class TestDataCreator:
|
||||
submissions.append(submission.model_dump())
|
||||
print(f"✅ Created store submission: {submission.name}")
|
||||
|
||||
if submission.store_listing_version_id:
|
||||
if submission.listing_version_id:
|
||||
# DETERMINISTIC: First N submissions are always approved
|
||||
# First GUARANTEED_FEATURED_AGENTS of those are always featured
|
||||
should_approve = (
|
||||
@@ -667,7 +667,7 @@ class TestDataCreator:
|
||||
try:
|
||||
reviewer_id = random.choice(self.users)["id"]
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=submission.store_listing_version_id,
|
||||
store_listing_version_id=submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Auto-approved for E2E testing",
|
||||
internal_comments="Automatically approved by E2E test data script",
|
||||
@@ -683,9 +683,7 @@ class TestDataCreator:
|
||||
if should_feature:
|
||||
try:
|
||||
await prisma.storelistingversion.update(
|
||||
where={
|
||||
"id": submission.store_listing_version_id
|
||||
},
|
||||
where={"id": submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
@@ -699,9 +697,7 @@ class TestDataCreator:
|
||||
elif random.random() < 0.2:
|
||||
try:
|
||||
await prisma.storelistingversion.update(
|
||||
where={
|
||||
"id": submission.store_listing_version_id
|
||||
},
|
||||
where={"id": submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
@@ -721,7 +717,7 @@ class TestDataCreator:
|
||||
try:
|
||||
reviewer_id = random.choice(self.users)["id"]
|
||||
await review_store_submission(
|
||||
store_listing_version_id=submission.store_listing_version_id,
|
||||
store_listing_version_id=submission.listing_version_id,
|
||||
is_approved=False,
|
||||
external_comments="Submission rejected - needs improvements",
|
||||
internal_comments="Automatically rejected by E2E test data script",
|
||||
|
||||
@@ -394,7 +394,6 @@ async def main():
|
||||
listing = await db.storelisting.create(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"owningUserId": user.id,
|
||||
"hasApprovedVersion": random.choice([True, False]),
|
||||
"slug": slug,
|
||||
|
||||
393
autogpt_platform/docs/beta-invite-design.md
Normal file
393
autogpt_platform/docs/beta-invite-design.md
Normal file
@@ -0,0 +1,393 @@
|
||||
# Beta Invite Pre-Provisioning Design
|
||||
|
||||
## Problem
|
||||
|
||||
The current signup path is split across three places:
|
||||
|
||||
- Supabase creates the auth user and password.
|
||||
- The backend lazily creates `platform.User` on first authenticated request in [`backend/backend/data/user.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/user.py).
|
||||
- A Postgres trigger creates `platform.Profile` after `auth.users` insert in [`backend/migrations/20250205100104_add_profile_trigger/migration.sql`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/migrations/20250205100104_add_profile_trigger/migration.sql).
|
||||
|
||||
That works for open signup plus a DB-level allowlist, but it does not give the platform a durable object representing:
|
||||
|
||||
- a beta invite before the person signs up
|
||||
- pre-computed onboarding data tied to an email before auth exists
|
||||
- the distinction between "invited", "claimed", "active", and "revoked"
|
||||
|
||||
It also makes first-login setup racey because `User`, `Profile`, `UserOnboarding`, and `CoPilotUnderstanding` are not created from one source of truth.
|
||||
|
||||
## Goals
|
||||
|
||||
- Allow staff to invite a user before they have a Supabase account.
|
||||
- Pre-create the platform-side user record and related data before first login.
|
||||
- Keep Supabase responsible for password entry, email verification, sessions, and OAuth providers.
|
||||
- Support both password signup and magic-link / OAuth signup for invited users.
|
||||
- Preserve the existing ability to populate Tally-derived understanding and onboarding defaults.
|
||||
- Make first login idempotent and safe if the user retries or signs in with a different method.
|
||||
|
||||
## Non-goals
|
||||
|
||||
- Replacing Supabase Auth.
|
||||
- Building a full enterprise identity management system.
|
||||
- Solving general-purpose team/org invites in this change.
|
||||
|
||||
## Proposed model
|
||||
|
||||
Introduce invite-backed pre-provisioning with email as the pre-auth identity key, then bind the invite to the Supabase `auth.users.id` when the user claims it.
|
||||
|
||||
### New tables
|
||||
|
||||
#### `BetaInvite`
|
||||
|
||||
Represents an invitation sent to a person before they create credentials.
|
||||
|
||||
Suggested fields:
|
||||
|
||||
- `id`
|
||||
- `email` unique
|
||||
- `status` enum: `PENDING`, `CLAIMED`, `EXPIRED`, `REVOKED`
|
||||
- `inviteTokenHash` nullable
|
||||
- `invitedByUserId` nullable
|
||||
- `expiresAt` nullable
|
||||
- `claimedAt` nullable
|
||||
- `claimedAuthUserId` nullable unique
|
||||
- `metadata` jsonb
|
||||
- `createdAt`
|
||||
- `updatedAt`
|
||||
|
||||
`metadata` should hold operational fields only, for example:
|
||||
|
||||
- source: manual import, admin UI, CSV
|
||||
- cohort: beta wave name
|
||||
- notes
|
||||
- original tally email if normalization differs
|
||||
|
||||
#### `PreProvisionedUser`
|
||||
|
||||
Represents the platform-side user state that exists before Supabase auth is bound.
|
||||
|
||||
Suggested fields:
|
||||
|
||||
- `id`
|
||||
- `inviteId` unique
|
||||
- `email` unique
|
||||
- `authUserId` nullable unique
|
||||
- `status` enum: `PENDING_CLAIM`, `ACTIVE`, `MERGE_REQUIRED`, `DISABLED`
|
||||
- `name` nullable
|
||||
- `timezone` default `not-set`
|
||||
- `emailVerified` default `false`
|
||||
- `metadata` jsonb
|
||||
- `createdAt`
|
||||
- `updatedAt`
|
||||
|
||||
This is the durable record staff can enrich before login.
|
||||
|
||||
#### `PreProvisionedUserSeed`
|
||||
|
||||
Stores structured seed data that should become first-class user records on claim.
|
||||
|
||||
Suggested fields:
|
||||
|
||||
- `id`
|
||||
- `preProvisionedUserId` unique
|
||||
- `profile` jsonb nullable
|
||||
- `onboarding` jsonb nullable
|
||||
- `businessUnderstanding` jsonb nullable
|
||||
- `promptContext` jsonb nullable
|
||||
- `createdAt`
|
||||
- `updatedAt`
|
||||
|
||||
This avoids polluting the main `User.metadata` blob and gives explicit ownership over what is seed data versus ongoing user-authored data.
|
||||
|
||||
## Why not pre-create `platform.User` directly?
|
||||
|
||||
`platform.User.id` is currently designed to equal the Supabase user id in [`backend/schema.prisma`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/schema.prisma). Pre-creating `User` with a fake UUID would ripple through every FK and complicate the eventual bind. Reusing email as the join key inside `User` is also not safe enough because the rest of the system assumes `User.id` is the canonical identity.
|
||||
|
||||
A separate pre-provisioning layer is lower risk:
|
||||
|
||||
- it avoids breaking every relation off `User`
|
||||
- it keeps the existing auth token model intact
|
||||
- it gives a clean migration path to merge into `User` once a Supabase identity exists
|
||||
|
||||
## Claim flow
|
||||
|
||||
### 1. Staff creates invite
|
||||
|
||||
Admin API or script:
|
||||
|
||||
1. Create `BetaInvite`.
|
||||
2. Create `PreProvisionedUser`.
|
||||
3. Create `PreProvisionedUserSeed`.
|
||||
4. Optionally fetch and store Tally-derived understanding immediately by email.
|
||||
5. Optionally send invite email containing a claim link.
|
||||
|
||||
### 2. User opens signup page
|
||||
|
||||
Two supported modes:
|
||||
|
||||
- direct signup with invited email
|
||||
- invite-link signup with a short-lived token
|
||||
|
||||
Preferred UX:
|
||||
|
||||
- `/signup?invite=<token>`
|
||||
- frontend resolves token to masked email and invite status
|
||||
- email field is prefilled and locked by default
|
||||
|
||||
### 3. User creates Supabase account
|
||||
|
||||
Frontend still calls Supabase `signUp`, but include invite context in user metadata:
|
||||
|
||||
- `invite_id`
|
||||
- `invite_email`
|
||||
|
||||
This is useful for debugging but should not be the source of truth.
|
||||
|
||||
### 4. First authenticated backend call activates invite
|
||||
|
||||
Replace the current pure `get_or_create_user` behavior with:
|
||||
|
||||
1. Read JWT `sub` and `email`.
|
||||
2. Look for existing `platform.User` by `id`.
|
||||
3. If found, return it.
|
||||
4. Else look for `PreProvisionedUser` by normalized email and `status = PENDING_CLAIM`.
|
||||
5. In one transaction:
|
||||
- create `platform.User` with `id = auth sub`
|
||||
- bind `PreProvisionedUser.authUserId = auth sub`
|
||||
- mark `PreProvisionedUser.status = ACTIVE`
|
||||
- mark `BetaInvite.status = CLAIMED`
|
||||
- create or upsert `Profile`
|
||||
- create or upsert `UserOnboarding`
|
||||
- create or upsert `CoPilotUnderstanding`
|
||||
- create `UserWorkspace` if required for the product experience
|
||||
6. If no pre-provisioned record exists, either:
|
||||
- reject access for closed beta, or
|
||||
- fall back to normal creation if a feature flag allows open signup
|
||||
|
||||
This activation logic should live in the backend service, not a Postgres trigger, because it needs to coordinate across multiple platform tables and apply merge rules.
|
||||
|
||||
## Data merge rules
|
||||
|
||||
When activating a pre-provisioned invite:
|
||||
|
||||
- `User`
|
||||
- source of truth for `id` is Supabase JWT `sub`
|
||||
- source of truth for `email` is Supabase auth email
|
||||
- `name` prefers pre-provisioned value, falls back to auth metadata name
|
||||
|
||||
- `Profile`
|
||||
- if seed profile exists, use it
|
||||
- else create the current generated default
|
||||
- never overwrite an existing profile if activation is retried
|
||||
|
||||
- `UserOnboarding`
|
||||
- seed values are initial defaults only
|
||||
- use `upsert` with create-on-missing semantics
|
||||
|
||||
- `CoPilotUnderstanding`
|
||||
- seed from stored Tally extraction if present
|
||||
- skip Tally backfill on first login if this already exists
|
||||
|
||||
- prompt/tally-specific context
|
||||
- do not jam this into `User.metadata` unless no better typed model exists
|
||||
- use `PreProvisionedUserSeed.promptContext` now, migrate to typed tables later if product solidifies
|
||||
|
||||
## Invite enforcement
|
||||
|
||||
The current closed-beta gate appears to rely on a Supabase-side DB error surfaced as "not allowed" in [`frontend/src/app/api/auth/utils.ts`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/api/auth/utils.ts).
|
||||
|
||||
That should evolve to:
|
||||
|
||||
### Short term
|
||||
|
||||
Keep the current Supabase signup restriction, but have it check `BetaInvite`/`PreProvisionedUser` instead of a raw allowlist.
|
||||
|
||||
### Medium term
|
||||
|
||||
Stop using a generic DB exception as the main product signal and expose a backend endpoint:
|
||||
|
||||
- `POST /internal/beta-invites/validate`
|
||||
- `GET /internal/beta-invites/:token`
|
||||
|
||||
Then the frontend can fail earlier with a specific state:
|
||||
|
||||
- invited and ready
|
||||
- invite expired
|
||||
- already claimed
|
||||
- not invited
|
||||
|
||||
Supabase should still remain the hard gate for credential creation, but the UI should stop depending on opaque trigger failures for normal control flow.
|
||||
|
||||
## Trigger changes
|
||||
|
||||
The `auth.users` trigger that creates `platform.User` and `platform.Profile` should be removed once activation logic is live.
|
||||
|
||||
Reason:
|
||||
|
||||
- it cannot see or safely apply pre-provisioned seed data
|
||||
- it only handles create, not merge/bind
|
||||
- it duplicates responsibility already present in `get_or_create_user`
|
||||
|
||||
Interim state during rollout:
|
||||
|
||||
- keep trigger disabled for production once backend activation ships
|
||||
- keep a backfill script for any auth users created during the transition
|
||||
|
||||
## API surface
|
||||
|
||||
### Admin APIs
|
||||
|
||||
- `POST /admin/beta-invites`
|
||||
- create invite + pre-provisioned user + seed data
|
||||
- `POST /admin/beta-invites/bulk`
|
||||
- CSV import for beta cohorts
|
||||
- `POST /admin/beta-invites/:id/resend`
|
||||
- `POST /admin/beta-invites/:id/revoke`
|
||||
- `GET /admin/beta-invites`
|
||||
|
||||
### Public/auth-adjacent APIs
|
||||
|
||||
- `GET /beta-invites/lookup?token=...`
|
||||
- return safe invite info for signup page
|
||||
- `POST /beta-invites/claim-preview`
|
||||
- optional: check whether an email is invited before attempting Supabase signup
|
||||
|
||||
### Internal service function changes
|
||||
|
||||
- Replace `get_or_create_user` with `get_or_activate_user`
|
||||
- keep a compatibility wrapper if needed to limit churn
|
||||
|
||||
## Suggested backend implementation
|
||||
|
||||
### New module
|
||||
|
||||
- `backend/backend/data/beta_invite.py`
|
||||
|
||||
Responsibilities:
|
||||
|
||||
- create invite
|
||||
- resolve invite token
|
||||
- normalize email
|
||||
- activate pre-provisioned user
|
||||
- revoke / expire invite
|
||||
|
||||
### Existing module changes
|
||||
|
||||
- [`backend/backend/data/user.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/user.py)
|
||||
- move lazy creation logic into activation-aware flow
|
||||
|
||||
- [`backend/backend/api/features/v1.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/api/features/v1.py)
|
||||
- `/auth/user` should call activation-aware function
|
||||
- only run background Tally population if no seeded understanding exists
|
||||
|
||||
- [`backend/backend/data/tally.py`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/backend/backend/data/tally.py)
|
||||
- add a reusable "seed from email" function for invite creation time
|
||||
|
||||
## Suggested frontend implementation
|
||||
|
||||
### Signup
|
||||
|
||||
- read invite token from URL
|
||||
- call invite lookup endpoint
|
||||
- prefill locked email when token is valid
|
||||
- if invite is invalid, show specific error, not generic waitlist modal
|
||||
|
||||
Files likely affected:
|
||||
|
||||
- [`frontend/src/app/(platform)/signup/page.tsx`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/(platform)/signup/page.tsx)
|
||||
- [`frontend/src/app/(platform)/signup/actions.ts`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts)
|
||||
- [`frontend/src/components/auth/WaitlistErrorContent.tsx`](/Users/swifty/work/agpt/AutoGPT/autogpt_platform/frontend/src/components/auth/WaitlistErrorContent.tsx)
|
||||
|
||||
### Admin UI
|
||||
|
||||
Add a simple internal page for beta invites instead of manually editing allowlists.
|
||||
|
||||
Possible location:
|
||||
|
||||
- `frontend/src/app/(platform)/admin/users/invites/page.tsx`
|
||||
|
||||
## Rollout plan
|
||||
|
||||
### Phase 1: schema and service layer
|
||||
|
||||
- add `BetaInvite`, `PreProvisionedUser`, `PreProvisionedUserSeed`
|
||||
- implement activation transaction
|
||||
- keep existing trigger/allowlist in place
|
||||
|
||||
### Phase 2: admin creation path
|
||||
|
||||
- add admin API or CLI script
|
||||
- support single invite and CSV bulk upload
|
||||
- seed Tally/business understanding during invite creation
|
||||
|
||||
### Phase 3: signup UX
|
||||
|
||||
- invite token lookup
|
||||
- better invite state messaging
|
||||
- preserve existing closed beta modal for non-invited traffic
|
||||
|
||||
### Phase 4: remove legacy coupling
|
||||
|
||||
- disable `auth.users` profile trigger
|
||||
- simplify `get_or_create_user`
|
||||
- migrate allowlist logic to invite tables
|
||||
|
||||
## Edge cases
|
||||
|
||||
### Existing auth user, no platform user
|
||||
|
||||
This already happens today. Activation flow should treat it as:
|
||||
|
||||
- if auth email matches a pending pre-provisioned invite, bind and activate
|
||||
- else create a plain `User` only if open-signup feature flag is enabled
|
||||
|
||||
### Existing platform user, invited again
|
||||
|
||||
Do not create another `PreProvisionedUser`. Create a second invite only if product explicitly wants re-invites. Otherwise reject as duplicate.
|
||||
|
||||
### Email changed after invite
|
||||
|
||||
Support admin-side reissue:
|
||||
|
||||
- revoke old invite
|
||||
- create new invite with new email
|
||||
- move seed data forward
|
||||
|
||||
Do not automatically bind across unrelated email addresses.
|
||||
|
||||
### OAuth signup
|
||||
|
||||
OAuth provider signups still work as long as the resulting Supabase email matches the invited email. If the provider returns a different email, activation should fail with a clear UI message.
|
||||
|
||||
### Tally data arrives after invite creation
|
||||
|
||||
Allow re-seeding before claim if `CoPilotUnderstanding` has not yet been created on activation.
|
||||
|
||||
## Migration notes
|
||||
|
||||
### Existing beta users
|
||||
|
||||
No immediate migration required for already-active users. This system is mainly for future invited users.
|
||||
|
||||
### Existing allowlist entries
|
||||
|
||||
Backfill them into `BetaInvite` plus `PreProvisionedUser`, then swap the Supabase gating logic to consult invite tables instead of the legacy allowlist table/trigger.
|
||||
|
||||
## Recommendation
|
||||
|
||||
Implement the pre-provisioning layer first and keep `platform.User` bound to Supabase `auth.users.id`. That is the lowest-risk design because it respects the existing identity model while giving the business exactly what it needs:
|
||||
|
||||
- invite someone before signup
|
||||
- compute and store Tally/onboarding/prompt defaults before first login
|
||||
- activate those defaults atomically when the user actually creates credentials
|
||||
|
||||
## First implementation slice
|
||||
|
||||
The smallest useful slice is:
|
||||
|
||||
1. Add `BetaInvite`, `PreProvisionedUser`, and `PreProvisionedUserSeed`.
|
||||
2. Add a backend admin endpoint to create an invite from email plus optional seed payload.
|
||||
3. Change `/auth/user` activation logic to bind and materialize seeded `Profile`, `UserOnboarding`, and `CoPilotUnderstanding`.
|
||||
4. Keep the existing signup UI, but validate invite membership before or during signup.
|
||||
|
||||
That delivers the core behavior without needing the full admin UI on day one.
|
||||
@@ -10,6 +10,7 @@
|
||||
"cssVariables": false,
|
||||
"prefix": ""
|
||||
},
|
||||
"iconLibrary": "radix",
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils"
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import {
|
||||
Users,
|
||||
DollarSign,
|
||||
UserSearch,
|
||||
FileText,
|
||||
UserPlus,
|
||||
} from "lucide-react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -16,6 +22,11 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/spending",
|
||||
icon: <DollarSign className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Beta Invites",
|
||||
href: "/admin/users",
|
||||
icon: <UserPlus className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Impersonation",
|
||||
href: "/admin/impersonation",
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
"use server";
|
||||
|
||||
import { revalidatePath } from "next/cache";
|
||||
import BackendApi from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
StoreListingsWithVersionsResponse,
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
getV2GetAdminListingsHistory,
|
||||
postV2ReviewStoreSubmission,
|
||||
getV2AdminDownloadAgentFile,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
|
||||
export async function approveAgent(formData: FormData) {
|
||||
const data = {
|
||||
store_listing_version_id: formData.get("id") as string,
|
||||
const storeListingVersionId = formData.get("id") as string;
|
||||
const comments = formData.get("comments") as string;
|
||||
|
||||
await postV2ReviewStoreSubmission(storeListingVersionId, {
|
||||
store_listing_version_id: storeListingVersionId,
|
||||
is_approved: true,
|
||||
comments: formData.get("comments") as string,
|
||||
};
|
||||
const api = new BackendApi();
|
||||
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
|
||||
comments,
|
||||
});
|
||||
|
||||
revalidatePath("/admin/marketplace");
|
||||
}
|
||||
|
||||
export async function rejectAgent(formData: FormData) {
|
||||
const data = {
|
||||
store_listing_version_id: formData.get("id") as string,
|
||||
const storeListingVersionId = formData.get("id") as string;
|
||||
const comments = formData.get("comments") as string;
|
||||
const internal_comments =
|
||||
(formData.get("internal_comments") as string) || undefined;
|
||||
|
||||
await postV2ReviewStoreSubmission(storeListingVersionId, {
|
||||
store_listing_version_id: storeListingVersionId,
|
||||
is_approved: false,
|
||||
comments: formData.get("comments") as string,
|
||||
internal_comments: formData.get("internal_comments") as string,
|
||||
};
|
||||
const api = new BackendApi();
|
||||
await api.reviewSubmissionAdmin(data.store_listing_version_id, data);
|
||||
comments,
|
||||
internal_comments,
|
||||
});
|
||||
|
||||
revalidatePath("/admin/marketplace");
|
||||
}
|
||||
@@ -37,26 +43,18 @@ export async function getAdminListingsWithVersions(
|
||||
search?: string,
|
||||
page: number = 1,
|
||||
pageSize: number = 20,
|
||||
): Promise<StoreListingsWithVersionsResponse> {
|
||||
const data: Record<string, any> = {
|
||||
) {
|
||||
const response = await getV2GetAdminListingsHistory({
|
||||
status,
|
||||
search,
|
||||
page,
|
||||
page_size: pageSize,
|
||||
};
|
||||
});
|
||||
|
||||
if (status) {
|
||||
data.status = status;
|
||||
}
|
||||
|
||||
if (search) {
|
||||
data.search = search;
|
||||
}
|
||||
const api = new BackendApi();
|
||||
const response = await api.getAdminListingsWithVersions(data);
|
||||
return response;
|
||||
return okData(response);
|
||||
}
|
||||
|
||||
export async function downloadAsAdmin(storeListingVersion: string) {
|
||||
const api = new BackendApi();
|
||||
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
|
||||
return file;
|
||||
const response = await getV2AdminDownloadAgentFile(storeListingVersion);
|
||||
return okData(response);
|
||||
}
|
||||
|
||||
@@ -6,10 +6,8 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import {
|
||||
StoreSubmission,
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
|
||||
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { PaginationControls } from "../../../../../components/__legacy__/ui/pagination-controls";
|
||||
import { getAdminListingsWithVersions } from "@/app/(platform)/admin/marketplace/actions";
|
||||
import { ExpandableRow } from "./ExpandleRow";
|
||||
@@ -17,12 +15,14 @@ import { SearchAndFilterAdminMarketplace } from "./SearchFilterForm";
|
||||
|
||||
// Helper function to get the latest version by version number
|
||||
const getLatestVersionByNumber = (
|
||||
versions: StoreSubmission[],
|
||||
): StoreSubmission | null => {
|
||||
versions: StoreSubmissionAdminView[] | undefined,
|
||||
): StoreSubmissionAdminView | null => {
|
||||
if (!versions || versions.length === 0) return null;
|
||||
return versions.reduce(
|
||||
(latest, current) =>
|
||||
(current.version ?? 0) > (latest.version ?? 1) ? current : latest,
|
||||
(current.listing_version ?? 0) > (latest.listing_version ?? 1)
|
||||
? current
|
||||
: latest,
|
||||
versions[0],
|
||||
);
|
||||
};
|
||||
@@ -37,12 +37,14 @@ export async function AdminAgentsDataTable({
|
||||
initialSearch?: string;
|
||||
}) {
|
||||
// Server-side data fetching
|
||||
const { listings, pagination } = await getAdminListingsWithVersions(
|
||||
const data = await getAdminListingsWithVersions(
|
||||
initialStatus,
|
||||
initialSearch,
|
||||
initialPage,
|
||||
10,
|
||||
);
|
||||
const listings = data?.listings ?? [];
|
||||
const pagination = data?.pagination;
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
@@ -92,7 +94,7 @@ export async function AdminAgentsDataTable({
|
||||
|
||||
<PaginationControls
|
||||
currentPage={initialPage}
|
||||
totalPages={pagination.total_pages}
|
||||
totalPages={pagination?.total_pages ?? 1}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -13,7 +13,7 @@ import {
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { Textarea } from "@/components/__legacy__/ui/textarea";
|
||||
import type { StoreSubmission } from "@/lib/autogpt-server-api/types";
|
||||
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
|
||||
import { useRouter } from "next/navigation";
|
||||
import {
|
||||
approveAgent,
|
||||
@@ -23,7 +23,7 @@ import {
|
||||
export function ApproveRejectButtons({
|
||||
version,
|
||||
}: {
|
||||
version: StoreSubmission;
|
||||
version: StoreSubmissionAdminView;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const [isApproveDialogOpen, setIsApproveDialogOpen] = useState(false);
|
||||
@@ -95,7 +95,7 @@ export function ApproveRejectButtons({
|
||||
<input
|
||||
type="hidden"
|
||||
name="id"
|
||||
value={version.store_listing_version_id || ""}
|
||||
value={version.listing_version_id || ""}
|
||||
/>
|
||||
|
||||
<div className="grid gap-4 py-4">
|
||||
@@ -142,7 +142,7 @@ export function ApproveRejectButtons({
|
||||
<input
|
||||
type="hidden"
|
||||
name="id"
|
||||
value={version.store_listing_version_id || ""}
|
||||
value={version.listing_version_id || ""}
|
||||
/>
|
||||
|
||||
<div className="grid gap-4 py-4">
|
||||
|
||||
@@ -12,11 +12,9 @@ import {
|
||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
||||
import { ChevronDown, ChevronRight } from "lucide-react";
|
||||
import { formatDistanceToNow } from "date-fns";
|
||||
import {
|
||||
type StoreListingWithVersions,
|
||||
type StoreSubmission,
|
||||
SubmissionStatus,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import type { StoreListingWithVersionsAdminView } from "@/app/api/__generated__/models/storeListingWithVersionsAdminView";
|
||||
import type { StoreSubmissionAdminView } from "@/app/api/__generated__/models/storeSubmissionAdminView";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { ApproveRejectButtons } from "./ApproveRejectButton";
|
||||
import { DownloadAgentAdminButton } from "./DownloadAgentButton";
|
||||
|
||||
@@ -38,8 +36,8 @@ export function ExpandableRow({
|
||||
listing,
|
||||
latestVersion,
|
||||
}: {
|
||||
listing: StoreListingWithVersions;
|
||||
latestVersion: StoreSubmission | null;
|
||||
listing: StoreListingWithVersionsAdminView;
|
||||
latestVersion: StoreSubmissionAdminView | null;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
|
||||
@@ -69,17 +67,17 @@ export function ExpandableRow({
|
||||
{latestVersion?.status && getStatusBadge(latestVersion.status)}
|
||||
</TableCell>
|
||||
<TableCell onClick={() => setExpanded(!expanded)}>
|
||||
{latestVersion?.date_submitted
|
||||
? formatDistanceToNow(new Date(latestVersion.date_submitted), {
|
||||
{latestVersion?.submitted_at
|
||||
? formatDistanceToNow(new Date(latestVersion.submitted_at), {
|
||||
addSuffix: true,
|
||||
})
|
||||
: "Unknown"}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-2">
|
||||
{latestVersion?.store_listing_version_id && (
|
||||
{latestVersion?.listing_version_id && (
|
||||
<DownloadAgentAdminButton
|
||||
storeListingVersionId={latestVersion.store_listing_version_id}
|
||||
storeListingVersionId={latestVersion.listing_version_id}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -115,14 +113,17 @@ export function ExpandableRow({
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{listing.versions
|
||||
.sort((a, b) => (b.version ?? 1) - (a.version ?? 0))
|
||||
{(listing.versions ?? [])
|
||||
.sort(
|
||||
(a, b) =>
|
||||
(b.listing_version ?? 1) - (a.listing_version ?? 0),
|
||||
)
|
||||
.map((version) => (
|
||||
<TableRow key={version.store_listing_version_id}>
|
||||
<TableRow key={version.listing_version_id}>
|
||||
<TableCell>
|
||||
v{version.version || "?"}
|
||||
{version.store_listing_version_id ===
|
||||
listing.active_version_id && (
|
||||
v{version.listing_version || "?"}
|
||||
{version.listing_version_id ===
|
||||
listing.active_listing_version_id && (
|
||||
<Badge className="ml-2 bg-blue-500">Active</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
@@ -131,9 +132,9 @@ export function ExpandableRow({
|
||||
{version.changes_summary || "No summary"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{version.date_submitted
|
||||
{version.submitted_at
|
||||
? formatDistanceToNow(
|
||||
new Date(version.date_submitted),
|
||||
new Date(version.submitted_at),
|
||||
{ addSuffix: true },
|
||||
)
|
||||
: "Unknown"}
|
||||
@@ -182,10 +183,10 @@ export function ExpandableRow({
|
||||
{/* <TableCell>{version.categories.join(", ")}</TableCell> */}
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-2">
|
||||
{version.store_listing_version_id && (
|
||||
{version.listing_version_id && (
|
||||
<DownloadAgentAdminButton
|
||||
storeListingVersionId={
|
||||
version.store_listing_version_id
|
||||
version.listing_version_id
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
import { SubmissionStatus } from "@/lib/autogpt-server-api/types";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
|
||||
export function SearchAndFilterAdminMarketplace({
|
||||
initialSearch,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { Suspense } from "react";
|
||||
import type { SubmissionStatus } from "@/lib/autogpt-server-api/types";
|
||||
import type { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { AdminAgentsDataTable } from "./components/AdminAgentsDataTable";
|
||||
|
||||
type MarketplaceAdminPageSearchParams = {
|
||||
page?: string;
|
||||
status?: string;
|
||||
status?: SubmissionStatus;
|
||||
search?: string;
|
||||
};
|
||||
|
||||
@@ -15,7 +15,7 @@ async function AdminMarketplaceDashboard({
|
||||
searchParams: MarketplaceAdminPageSearchParams;
|
||||
}) {
|
||||
const page = searchParams.page ? Number.parseInt(searchParams.page) : 1;
|
||||
const status = searchParams.status as SubmissionStatus | undefined;
|
||||
const status = searchParams.status;
|
||||
const search = searchParams.search;
|
||||
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"use client";
|
||||
|
||||
import { Card } from "@/components/atoms/Card/Card";
|
||||
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
|
||||
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
|
||||
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
|
||||
import { useAdminUsersPage } from "../../useAdminUsersPage";
|
||||
|
||||
export function AdminUsersPage() {
|
||||
const {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers,
|
||||
isLoadingInvitedUsers,
|
||||
isRefreshingInvitedUsers,
|
||||
isCreatingInvite,
|
||||
isBulkInviting,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
} = useAdminUsersPage();
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
|
||||
<p className="max-w-3xl text-sm text-zinc-600">
|
||||
Pre-provision beta users before they sign up. Invites store the
|
||||
platform-side record, run Tally understanding extraction, and activate
|
||||
the real account on the user's first authenticated request.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
|
||||
<div className="flex flex-col gap-6">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InviteUserForm
|
||||
email={email}
|
||||
name={name}
|
||||
isSubmitting={isCreatingInvite}
|
||||
onEmailChange={setEmail}
|
||||
onNameChange={setName}
|
||||
onSubmit={handleCreateInvite}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<BulkInviteForm
|
||||
selectedFile={bulkInviteFile}
|
||||
inputKey={bulkInviteInputKey}
|
||||
isSubmitting={isBulkInviting}
|
||||
lastResult={lastBulkInviteResult}
|
||||
onFileChange={handleBulkInviteFileChange}
|
||||
onSubmit={handleBulkInviteSubmit}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InvitedUsersTable
|
||||
invitedUsers={invitedUsers}
|
||||
isLoading={isLoadingInvitedUsers}
|
||||
isRefreshing={isRefreshingInvitedUsers}
|
||||
pendingInviteAction={pendingInviteAction}
|
||||
onRetryTally={handleRetryTally}
|
||||
onRevoke={handleRevoke}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
selectedFile: File | null;
|
||||
inputKey: number;
|
||||
isSubmitting: boolean;
|
||||
lastResult: BulkInvitedUsersResponse | null;
|
||||
onFileChange: (file: File | null) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
|
||||
if (status === "CREATED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "ERROR") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
export function BulkInviteForm({
|
||||
selectedFile,
|
||||
inputKey,
|
||||
isSubmitting,
|
||||
lastResult,
|
||||
onFileChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Upload a <span className="font-medium text-zinc-800">.txt</span> file
|
||||
with one email per line, or a{" "}
|
||||
<span className="font-medium text-zinc-800">.csv</span> with
|
||||
<span className="font-medium text-zinc-800"> email</span> and optional
|
||||
<span className="font-medium text-zinc-800"> name</span> columns.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<label className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors hover:border-zinc-400 hover:bg-zinc-100">
|
||||
<span className="font-medium text-zinc-900">
|
||||
{selectedFile ? selectedFile.name : "Choose invite file"}
|
||||
</span>
|
||||
<span>Maximum 500 rows, UTF-8 encoded.</span>
|
||||
<input
|
||||
key={inputKey}
|
||||
type="file"
|
||||
accept=".txt,.csv,text/plain,text/csv"
|
||||
disabled={isSubmitting}
|
||||
className="hidden"
|
||||
onChange={(event) =>
|
||||
onFileChange(event.target.files?.item(0) ?? null)
|
||||
}
|
||||
/>
|
||||
</label>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!selectedFile}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
|
||||
</Button>
|
||||
|
||||
{lastResult ? (
|
||||
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
|
||||
<div className="grid grid-cols-3 gap-2 text-center">
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.created_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Created
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.skipped_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Skipped
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.error_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Errors
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
|
||||
<div className="flex flex-col divide-y divide-zinc-100">
|
||||
{lastResult.results.map((row) => (
|
||||
<div
|
||||
key={`${row.row_number}-${row.email ?? row.message}`}
|
||||
className="flex items-start gap-3 px-3 py-3"
|
||||
>
|
||||
<Badge variant={getStatusVariant(row.status)} size="small">
|
||||
{row.status}
|
||||
</Badge>
|
||||
<div className="flex min-w-0 flex-1 flex-col gap-1">
|
||||
<span className="text-sm font-medium text-zinc-900">
|
||||
Row {row.row_number}
|
||||
{row.email ? ` · ${row.email}` : ""}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-500">{row.message}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
email: string;
|
||||
name: string;
|
||||
isSubmitting: boolean;
|
||||
onEmailChange: (value: string) => void;
|
||||
onNameChange: (value: string) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
export function InviteUserForm({
|
||||
email,
|
||||
name,
|
||||
isSubmitting,
|
||||
onEmailChange,
|
||||
onNameChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
The invite is stored immediately, then Tally pre-seeding starts in the
|
||||
background.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Input
|
||||
id="invite-email"
|
||||
label="Email"
|
||||
type="email"
|
||||
value={email}
|
||||
placeholder="jane@example.com"
|
||||
autoComplete="email"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onEmailChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Input
|
||||
id="invite-name"
|
||||
label="Name"
|
||||
type="text"
|
||||
value={name}
|
||||
placeholder="Jane Doe"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onNameChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!email.trim()}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Creating invite..." : "Create invite"}
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
"use client";
|
||||
|
||||
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
interface Props {
|
||||
invitedUsers: InvitedUserResponse[];
|
||||
isLoading: boolean;
|
||||
isRefreshing: boolean;
|
||||
pendingInviteAction: string | null;
|
||||
onRetryTally: (invitedUserId: string) => void;
|
||||
onRevoke: (invitedUserId: string) => void;
|
||||
}
|
||||
|
||||
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
|
||||
if (status === "CLAIMED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "REVOKED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
|
||||
if (status === "READY") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "FAILED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function formatDate(value: Date | undefined) {
|
||||
if (!value) {
|
||||
return "-";
|
||||
}
|
||||
|
||||
return value.toLocaleString();
|
||||
}
|
||||
|
||||
function getTallySummary(invitedUser: InvitedUserResponse) {
|
||||
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
|
||||
return invitedUser.tally_error;
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
|
||||
return "Stored and ready for activation";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY") {
|
||||
return "No matching Tally submission found";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "RUNNING") {
|
||||
return "Extraction in progress";
|
||||
}
|
||||
|
||||
return "Waiting to run";
|
||||
}
|
||||
|
||||
function isActionPending(
|
||||
pendingInviteAction: string | null,
|
||||
action: "retry" | "revoke",
|
||||
invitedUserId: string,
|
||||
) {
|
||||
return pendingInviteAction === `${action}:${invitedUserId}`;
|
||||
}
|
||||
|
||||
export function InvitedUsersTable({
|
||||
invitedUsers,
|
||||
isLoading,
|
||||
isRefreshing,
|
||||
pendingInviteAction,
|
||||
onRetryTally,
|
||||
onRevoke,
|
||||
}: Props) {
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Live invite state, claim status, and Tally pre-seeding progress.
|
||||
</p>
|
||||
</div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="overflow-hidden rounded-2xl border border-zinc-200">
|
||||
<Table>
|
||||
<TableHeader className="bg-zinc-50">
|
||||
<TableRow>
|
||||
<TableHead>Email</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Invite</TableHead>
|
||||
<TableHead>Tally</TableHead>
|
||||
<TableHead>Claimed User</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead className="text-right">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{isLoading ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
Loading invited users...
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : invitedUsers.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
No invited users yet
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
invitedUsers.map((invitedUser) => (
|
||||
<TableRow key={invitedUser.id} className="align-top">
|
||||
<TableCell className="font-medium text-zinc-900">
|
||||
{invitedUser.email}
|
||||
</TableCell>
|
||||
<TableCell>{invitedUser.name || "-"}</TableCell>
|
||||
<TableCell>
|
||||
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
|
||||
{invitedUser.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex max-w-xs flex-col gap-2">
|
||||
<Badge
|
||||
variant={getTallyBadgeVariant(invitedUser.tally_status)}
|
||||
>
|
||||
{invitedUser.tally_status}
|
||||
</Badge>
|
||||
<span className="text-xs text-zinc-500">
|
||||
{getTallySummary(invitedUser)}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-400">
|
||||
{formatDate(invitedUser.tally_computed_at ?? undefined)}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="font-mono text-xs text-zinc-500">
|
||||
{invitedUser.auth_user_id || "-"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-500">
|
||||
{formatDate(invitedUser.created_at)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={invitedUser.status === "REVOKED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"retry",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRetryTally(invitedUser.id)}
|
||||
>
|
||||
Retry Tally
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
disabled={invitedUser.status !== "INVITED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"revoke",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRevoke(invitedUser.id)}
|
||||
>
|
||||
Revoke
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,16 +1,11 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import React from "react";
|
||||
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
|
||||
|
||||
function AdminUsers() {
|
||||
return (
|
||||
<div>
|
||||
<h1>Users Dashboard</h1>
|
||||
{/* Add your admin-only content here */}
|
||||
</div>
|
||||
);
|
||||
return <AdminUsersPage />;
|
||||
}
|
||||
|
||||
export default async function AdminUsersPage() {
|
||||
export default async function AdminUsersRoute() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import {
|
||||
getGetV2ListInvitedUsersQueryKey,
|
||||
useGetV2ListInvitedUsers,
|
||||
usePostV2BulkCreateInvitedUsers,
|
||||
usePostV2CreateInvitedUser,
|
||||
usePostV2RetryInvitedUserTally,
|
||||
usePostV2RevokeInvitedUser,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { ApiError } from "@/lib/autogpt-server-api/helpers";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { type FormEvent, useState } from "react";
|
||||
|
||||
function getErrorMessage(error: unknown) {
|
||||
if (error instanceof ApiError) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return "Something went wrong";
|
||||
}
|
||||
|
||||
export function useAdminUsersPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { toast } = useToast();
|
||||
const [email, setEmail] = useState("");
|
||||
const [name, setName] = useState("");
|
||||
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
|
||||
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
|
||||
const [lastBulkInviteResult, setLastBulkInviteResult] =
|
||||
useState<BulkInvitedUsersResponse | null>(null);
|
||||
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const invitedUsersQuery = useGetV2ListInvitedUsers({
|
||||
query: {
|
||||
select: okData,
|
||||
refetchInterval: 5000,
|
||||
},
|
||||
});
|
||||
|
||||
const createInvitedUserMutation = usePostV2CreateInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setEmail("");
|
||||
setName("");
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invited user created",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
|
||||
mutation: {
|
||||
onSuccess: async (response) => {
|
||||
const result = okData(response) ?? null;
|
||||
setBulkInviteFile(null);
|
||||
setBulkInviteInputKey((currentValue) => currentValue + 1);
|
||||
setLastBulkInviteResult(result);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: result
|
||||
? `${result.created_count} invites created`
|
||||
: "Bulk invite upload complete",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Tally pre-seeding restarted",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invite revoked",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
createInvitedUserMutation.mutate({
|
||||
data: {
|
||||
email,
|
||||
name: name.trim() || null,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRetryTally(invitedUserId: string) {
|
||||
setPendingInviteAction(`retry:${invitedUserId}`);
|
||||
retryInvitedUserTallyMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
function handleBulkInviteFileChange(file: File | null) {
|
||||
setBulkInviteFile(file);
|
||||
}
|
||||
|
||||
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!bulkInviteFile) {
|
||||
return;
|
||||
}
|
||||
|
||||
bulkCreateInvitedUsersMutation.mutate({
|
||||
data: {
|
||||
file: bulkInviteFile,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRevoke(invitedUserId: string) {
|
||||
setPendingInviteAction(`revoke:${invitedUserId}`);
|
||||
revokeInvitedUserMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
return {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
|
||||
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
|
||||
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
|
||||
isCreatingInvite: createInvitedUserMutation.isPending,
|
||||
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
};
|
||||
}
|
||||
@@ -151,6 +151,9 @@ export function ChatInput({
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
</PromptInputTools>
|
||||
|
||||
<div className="flex items-center gap-4">
|
||||
{showMicButton && (
|
||||
<RecordingButton
|
||||
isRecording={isRecording}
|
||||
@@ -160,13 +163,12 @@ export function ChatInput({
|
||||
onClick={toggleRecording}
|
||||
/>
|
||||
)}
|
||||
</PromptInputTools>
|
||||
|
||||
{isStreaming ? (
|
||||
<PromptInputSubmit status="streaming" onStop={onStop} />
|
||||
) : (
|
||||
<PromptInputSubmit disabled={!canSend} />
|
||||
)}
|
||||
{isStreaming ? (
|
||||
<PromptInputSubmit status="streaming" onStop={onStop} />
|
||||
) : (
|
||||
<PromptInputSubmit disabled={!canSend} />
|
||||
)}
|
||||
</div>
|
||||
</PromptInputFooter>
|
||||
</InputGroup>
|
||||
</form>
|
||||
|
||||
@@ -28,10 +28,9 @@ export function RecordingButton({
|
||||
disabled={disabled}
|
||||
onClick={onClick}
|
||||
className={cn(
|
||||
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
"border-0 bg-white text-zinc-500 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
disabled && "opacity-40",
|
||||
isRecording &&
|
||||
"animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600",
|
||||
isRecording && "animate-pulse bg-red-500 text-white hover:bg-red-600",
|
||||
isTranscribing && "bg-zinc-100 text-zinc-400",
|
||||
isStreaming && "opacity-40",
|
||||
)}
|
||||
|
||||
@@ -5,15 +5,19 @@ import {
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import { Message, MessageContent } from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { FileUIPart, ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { parseSpecialMarkers } from "./helpers";
|
||||
import { AssistantMessageActions } from "./components/AssistantMessageActions";
|
||||
import { CollapsedToolGroup } from "./components/CollapsedToolGroup";
|
||||
import { MessageAttachments } from "./components/MessageAttachments";
|
||||
import { MessagePartRenderer } from "./components/MessagePartRenderer";
|
||||
import { ReasoningCollapse } from "./components/ReasoningCollapse";
|
||||
import { ThinkingIndicator } from "./components/ThinkingIndicator";
|
||||
|
||||
type MessagePart = UIMessage<unknown, UIDataTypes, UITools>["parts"][number];
|
||||
|
||||
interface Props {
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
status: string;
|
||||
@@ -23,6 +27,132 @@ interface Props {
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
function isCompletedToolPart(part: MessagePart): part is ToolUIPart {
|
||||
return (
|
||||
part.type.startsWith("tool-") &&
|
||||
"state" in part &&
|
||||
(part.state === "output-available" || part.state === "output-error")
|
||||
);
|
||||
}
|
||||
|
||||
type RenderSegment =
|
||||
| { kind: "part"; part: MessagePart; index: number }
|
||||
| { kind: "collapsed-group"; parts: ToolUIPart[] };
|
||||
|
||||
// Tool types that have custom renderers and should NOT be collapsed
|
||||
const CUSTOM_TOOL_TYPES = new Set([
|
||||
"tool-find_block",
|
||||
"tool-find_agent",
|
||||
"tool-find_library_agent",
|
||||
"tool-search_docs",
|
||||
"tool-get_doc_page",
|
||||
"tool-run_block",
|
||||
"tool-run_mcp_tool",
|
||||
"tool-run_agent",
|
||||
"tool-schedule_agent",
|
||||
"tool-create_agent",
|
||||
"tool-edit_agent",
|
||||
"tool-view_agent_output",
|
||||
"tool-search_feature_requests",
|
||||
"tool-create_feature_request",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Groups consecutive completed generic tool parts into collapsed segments.
|
||||
* Non-generic tools (those with custom renderers) and active/streaming tools
|
||||
* are left as individual parts.
|
||||
*/
|
||||
function buildRenderSegments(
|
||||
parts: MessagePart[],
|
||||
baseIndex = 0,
|
||||
): RenderSegment[] {
|
||||
const segments: RenderSegment[] = [];
|
||||
let pendingGroup: Array<{ part: ToolUIPart; index: number }> | null = null;
|
||||
|
||||
function flushGroup() {
|
||||
if (!pendingGroup) return;
|
||||
if (pendingGroup.length >= 2) {
|
||||
segments.push({
|
||||
kind: "collapsed-group",
|
||||
parts: pendingGroup.map((p) => p.part),
|
||||
});
|
||||
} else {
|
||||
for (const p of pendingGroup) {
|
||||
segments.push({ kind: "part", part: p.part, index: p.index });
|
||||
}
|
||||
}
|
||||
pendingGroup = null;
|
||||
}
|
||||
|
||||
parts.forEach((part, i) => {
|
||||
const absoluteIndex = baseIndex + i;
|
||||
const isGenericCompletedTool =
|
||||
isCompletedToolPart(part) && !CUSTOM_TOOL_TYPES.has(part.type);
|
||||
|
||||
if (isGenericCompletedTool) {
|
||||
if (!pendingGroup) pendingGroup = [];
|
||||
pendingGroup.push({ part: part as ToolUIPart, index: absoluteIndex });
|
||||
} else {
|
||||
flushGroup();
|
||||
segments.push({ kind: "part", part, index: absoluteIndex });
|
||||
}
|
||||
});
|
||||
|
||||
flushGroup();
|
||||
return segments;
|
||||
}
|
||||
|
||||
/**
|
||||
* For finalized assistant messages, split parts into "reasoning" (intermediate
|
||||
* text + tools before the final response) and "response" (final text after the
|
||||
* last tool). If there are no tools, everything is response.
|
||||
*/
|
||||
function splitReasoningAndResponse(parts: MessagePart[]): {
|
||||
reasoning: MessagePart[];
|
||||
response: MessagePart[];
|
||||
} {
|
||||
const lastToolIndex = parts.findLastIndex((p) => p.type.startsWith("tool-"));
|
||||
|
||||
// No tools → everything is response
|
||||
if (lastToolIndex === -1) {
|
||||
return { reasoning: [], response: parts };
|
||||
}
|
||||
|
||||
// Check if there's any text after the last tool
|
||||
const hasResponseAfterTools = parts
|
||||
.slice(lastToolIndex + 1)
|
||||
.some((p) => p.type === "text");
|
||||
|
||||
if (!hasResponseAfterTools) {
|
||||
// No final text response → don't collapse anything
|
||||
return { reasoning: [], response: parts };
|
||||
}
|
||||
|
||||
return {
|
||||
reasoning: parts.slice(0, lastToolIndex + 1),
|
||||
response: parts.slice(lastToolIndex + 1),
|
||||
};
|
||||
}
|
||||
|
||||
function renderSegments(
|
||||
segments: RenderSegment[],
|
||||
messageID: string,
|
||||
): React.ReactNode[] {
|
||||
return segments.map((seg, segIdx) => {
|
||||
if (seg.kind === "collapsed-group") {
|
||||
return <CollapsedToolGroup key={`group-${segIdx}`} parts={seg.parts} />;
|
||||
}
|
||||
return (
|
||||
<MessagePartRenderer
|
||||
key={`${messageID}-${seg.index}`}
|
||||
part={seg.part}
|
||||
messageID={messageID}
|
||||
partIndex={seg.index}
|
||||
/>
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
/** Collect all messages belonging to a turn: the user message + every
|
||||
* assistant message up to (but not including) the next user message. */
|
||||
function getTurnMessages(
|
||||
@@ -119,6 +249,24 @@ export function ChatMessagesContainer({
|
||||
(p): p is FileUIPart => p.type === "file",
|
||||
);
|
||||
|
||||
// For finalized assistant messages, split into reasoning + response.
|
||||
// During streaming, show everything normally with tool collapsing.
|
||||
const isFinalized =
|
||||
message.role === "assistant" && !isCurrentlyStreaming;
|
||||
const { reasoning, response } = isFinalized
|
||||
? splitReasoningAndResponse(message.parts)
|
||||
: { reasoning: [] as MessagePart[], response: message.parts };
|
||||
const hasReasoning = reasoning.length > 0;
|
||||
|
||||
const responseStartIndex = message.parts.length - response.length;
|
||||
const responseSegments =
|
||||
message.role === "assistant"
|
||||
? buildRenderSegments(response, responseStartIndex)
|
||||
: null;
|
||||
const reasoningSegments = hasReasoning
|
||||
? buildRenderSegments(reasoning, 0)
|
||||
: null;
|
||||
|
||||
return (
|
||||
<Message from={message.role} key={message.id}>
|
||||
<MessageContent
|
||||
@@ -128,14 +276,21 @@ export function ChatMessagesContainer({
|
||||
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
|
||||
}
|
||||
>
|
||||
{message.parts.map((part, i) => (
|
||||
<MessagePartRenderer
|
||||
key={`${message.id}-${i}`}
|
||||
part={part}
|
||||
messageID={message.id}
|
||||
partIndex={i}
|
||||
/>
|
||||
))}
|
||||
{hasReasoning && reasoningSegments && (
|
||||
<ReasoningCollapse>
|
||||
{renderSegments(reasoningSegments, message.id)}
|
||||
</ReasoningCollapse>
|
||||
)}
|
||||
{responseSegments
|
||||
? renderSegments(responseSegments, message.id)
|
||||
: message.parts.map((part, i) => (
|
||||
<MessagePartRenderer
|
||||
key={`${message.id}-${i}`}
|
||||
part={part}
|
||||
messageID={message.id}
|
||||
partIndex={i}
|
||||
/>
|
||||
))}
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
<TurnStatsBar
|
||||
turnMessages={getTurnMessages(messages, messageIndex)}
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"use client";
|
||||
|
||||
import { useId, useState } from "react";
|
||||
import {
|
||||
ArrowsClockwiseIcon,
|
||||
CaretRightIcon,
|
||||
CheckCircleIcon,
|
||||
FileIcon,
|
||||
FilesIcon,
|
||||
GearIcon,
|
||||
GlobeIcon,
|
||||
ListChecksIcon,
|
||||
MagnifyingGlassIcon,
|
||||
MonitorIcon,
|
||||
PencilSimpleIcon,
|
||||
TerminalIcon,
|
||||
TrashIcon,
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import {
|
||||
type ToolCategory,
|
||||
extractToolName,
|
||||
getAnimationText,
|
||||
getToolCategory,
|
||||
} from "../../../tools/GenericTool/helpers";
|
||||
|
||||
interface Props {
|
||||
parts: ToolUIPart[];
|
||||
}
|
||||
|
||||
/** Category icon matching GenericTool's ToolIcon for completed states. */
|
||||
function EntryIcon({
|
||||
category,
|
||||
isError,
|
||||
}: {
|
||||
category: ToolCategory;
|
||||
isError: boolean;
|
||||
}) {
|
||||
if (isError) {
|
||||
return (
|
||||
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
|
||||
);
|
||||
}
|
||||
|
||||
const iconClass = "text-green-500";
|
||||
switch (category) {
|
||||
case "bash":
|
||||
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "web":
|
||||
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "browser":
|
||||
return <MonitorIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-delete":
|
||||
return <TrashIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-list":
|
||||
return <FilesIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "search":
|
||||
return (
|
||||
<MagnifyingGlassIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "edit":
|
||||
return (
|
||||
<PencilSimpleIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "todo":
|
||||
return (
|
||||
<ListChecksIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
case "compaction":
|
||||
return (
|
||||
<ArrowsClockwiseIcon size={14} weight="regular" className={iconClass} />
|
||||
);
|
||||
default:
|
||||
return <GearIcon size={14} weight="regular" className={iconClass} />;
|
||||
}
|
||||
}
|
||||
|
||||
export function CollapsedToolGroup({ parts }: Props) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const panelId = useId();
|
||||
|
||||
const errorCount = parts.filter((p) => p.state === "output-error").length;
|
||||
const label =
|
||||
errorCount > 0
|
||||
? `${parts.length} tool calls (${errorCount} failed)`
|
||||
: `${parts.length} tool calls completed`;
|
||||
|
||||
return (
|
||||
<div className="py-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
aria-expanded={expanded}
|
||||
aria-controls={panelId}
|
||||
className="flex items-center gap-1.5 text-sm text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<CaretRightIcon
|
||||
size={12}
|
||||
weight="bold"
|
||||
className={
|
||||
"transition-transform duration-150 " + (expanded ? "rotate-90" : "")
|
||||
}
|
||||
/>
|
||||
{errorCount > 0 ? (
|
||||
<WarningDiamondIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className="text-red-500"
|
||||
/>
|
||||
) : (
|
||||
<CheckCircleIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className="text-green-500"
|
||||
/>
|
||||
)}
|
||||
<span>{label}</span>
|
||||
</button>
|
||||
|
||||
{expanded && (
|
||||
<div
|
||||
id={panelId}
|
||||
className="ml-5 mt-1 space-y-0.5 border-l border-neutral-200 pl-3"
|
||||
>
|
||||
{parts.map((part) => {
|
||||
const toolName = extractToolName(part);
|
||||
const category = getToolCategory(toolName);
|
||||
const text = getAnimationText(part, category);
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
return (
|
||||
<div
|
||||
key={part.toolCallId}
|
||||
className={
|
||||
"flex items-center gap-1.5 text-xs " +
|
||||
(isError ? "text-red-500" : "text-muted-foreground")
|
||||
}
|
||||
>
|
||||
<EntryIcon category={category} isError={isError} />
|
||||
<span>{text}</span>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -15,6 +15,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
const [comment, setComment] = useState("");
|
||||
|
||||
function handleSubmit() {
|
||||
if (!comment.trim()) return;
|
||||
onSubmit(comment);
|
||||
setComment("");
|
||||
}
|
||||
@@ -36,7 +37,7 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="mx-auto w-[95%] space-y-4">
|
||||
<p className="text-sm text-slate-600">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Your feedback helps us improve. Share details below.
|
||||
</p>
|
||||
<Textarea
|
||||
@@ -48,12 +49,18 @@ export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
className="resize-none"
|
||||
/>
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-xs text-slate-400">{comment.length}/2000</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{comment.length}/2000
|
||||
</p>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button size="sm" onClick={handleSubmit}>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleSubmit}
|
||||
disabled={!comment.trim()}
|
||||
>
|
||||
Submit feedback
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
SearchFeatureRequestsTool,
|
||||
} from "../../../tools/FeatureRequests/FeatureRequests";
|
||||
import { FindAgentsTool } from "../../../tools/FindAgents/FindAgents";
|
||||
import { FolderTool } from "../../../tools/FolderTool/FolderTool";
|
||||
import { FindBlocksTool } from "../../../tools/FindBlocks/FindBlocks";
|
||||
import { GenericTool } from "../../../tools/GenericTool/GenericTool";
|
||||
import { RunAgentTool } from "../../../tools/RunAgent/RunAgent";
|
||||
@@ -145,6 +146,13 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
return <SearchFeatureRequestsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-create_feature_request":
|
||||
return <CreateFeatureRequestTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-create_folder":
|
||||
case "tool-list_folders":
|
||||
case "tool-update_folder":
|
||||
case "tool-move_folder":
|
||||
case "tool-delete_folder":
|
||||
case "tool-move_agents_to_folder":
|
||||
return <FolderTool key={key} part={part as ToolUIPart} />;
|
||||
default:
|
||||
// Render a generic tool indicator for SDK built-in
|
||||
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { LightbulbIcon } from "@phosphor-icons/react";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
|
||||
interface Props {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function ReasoningCollapse({ children }: Props) {
|
||||
return (
|
||||
<Dialog title="Reasoning">
|
||||
<Dialog.Trigger>
|
||||
<button
|
||||
type="button"
|
||||
className="flex items-center gap-1 text-xs text-zinc-500 transition-colors hover:text-zinc-700"
|
||||
>
|
||||
<LightbulbIcon size={12} weight="bold" />
|
||||
<span>Show reasoning</span>
|
||||
</button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="space-y-1">{children}</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
usePatchV2UpdateSessionTitle,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
@@ -17,7 +18,6 @@ import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import {
|
||||
Sidebar,
|
||||
SidebarContent,
|
||||
SidebarFooter,
|
||||
SidebarHeader,
|
||||
SidebarTrigger,
|
||||
useSidebar,
|
||||
@@ -25,8 +25,9 @@ import {
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { motion } from "framer-motion";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
|
||||
@@ -65,6 +66,39 @@ export function ChatSidebar() {
|
||||
},
|
||||
});
|
||||
|
||||
const [editingSessionId, setEditingSessionId] = useState<string | null>(null);
|
||||
const [editingTitle, setEditingTitle] = useState("");
|
||||
const renameInputRef = useRef<HTMLInputElement>(null);
|
||||
const renameCancelledRef = useRef(false);
|
||||
|
||||
const { mutate: renameSession } = usePatchV2UpdateSessionTitle({
|
||||
mutation: {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
setEditingSessionId(null);
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: "Failed to rename chat",
|
||||
description:
|
||||
error instanceof Error ? error.message : "An error occurred",
|
||||
variant: "destructive",
|
||||
});
|
||||
setEditingSessionId(null);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Auto-focus the rename input when editing starts
|
||||
useEffect(() => {
|
||||
if (editingSessionId && renameInputRef.current) {
|
||||
renameInputRef.current.focus();
|
||||
renameInputRef.current.select();
|
||||
}
|
||||
}, [editingSessionId]);
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
@@ -76,6 +110,26 @@ export function ChatSidebar() {
|
||||
setSessionId(id);
|
||||
}
|
||||
|
||||
function handleRenameClick(
|
||||
e: React.MouseEvent,
|
||||
id: string,
|
||||
title: string | null | undefined,
|
||||
) {
|
||||
e.stopPropagation();
|
||||
renameCancelledRef.current = false;
|
||||
setEditingSessionId(id);
|
||||
setEditingTitle(title || "");
|
||||
}
|
||||
|
||||
function handleRenameSubmit(id: string) {
|
||||
const trimmed = editingTitle.trim();
|
||||
if (trimmed) {
|
||||
renameSession({ sessionId: id, data: { title: trimmed } });
|
||||
} else {
|
||||
setEditingSessionId(null);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDeleteClick(
|
||||
e: React.MouseEvent,
|
||||
id: string,
|
||||
@@ -160,29 +214,42 @@ export function ChatSidebar() {
|
||||
</motion.div>
|
||||
</SidebarHeader>
|
||||
)}
|
||||
{!isCollapsed && (
|
||||
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.1 }}
|
||||
className="flex flex-col gap-3 px-3"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</motion.div>
|
||||
</SidebarHeader>
|
||||
)}
|
||||
|
||||
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{!isCollapsed && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.1 }}
|
||||
className="flex items-center justify-between px-3"
|
||||
>
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
|
||||
{!isCollapsed && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.15 }}
|
||||
className="mt-4 flex flex-col gap-1"
|
||||
className="flex flex-col gap-1"
|
||||
>
|
||||
{isLoadingSessions ? (
|
||||
<div className="flex min-h-[30rem] items-center justify-center py-4">
|
||||
@@ -203,76 +270,105 @@ export function ChatSidebar() {
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || `Untitled chat`}
|
||||
{editingSessionId === session.id ? (
|
||||
<div className="px-3 py-2.5">
|
||||
<input
|
||||
ref={renameInputRef}
|
||||
type="text"
|
||||
aria-label="Rename chat"
|
||||
value={editingTitle}
|
||||
onChange={(e) => setEditingTitle(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.currentTarget.blur();
|
||||
} else if (e.key === "Escape") {
|
||||
renameCancelledRef.current = true;
|
||||
setEditingSessionId(null);
|
||||
}
|
||||
}}
|
||||
onBlur={() => {
|
||||
if (renameCancelledRef.current) {
|
||||
renameCancelledRef.current = false;
|
||||
return;
|
||||
}
|
||||
handleRenameSubmit(session.id);
|
||||
}}
|
||||
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</button>
|
||||
)}
|
||||
{editingSessionId !== session.id && (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)}
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</motion.div>
|
||||
)}
|
||||
</SidebarContent>
|
||||
{!isCollapsed && sessionId && (
|
||||
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.2 }}
|
||||
>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</motion.div>
|
||||
</SidebarFooter>
|
||||
)}
|
||||
</Sidebar>
|
||||
|
||||
<DeleteChatDialog
|
||||
|
||||
@@ -29,7 +29,6 @@ export function DeleteChatDialog({
|
||||
}
|
||||
},
|
||||
}}
|
||||
onClose={isDeleting ? undefined : onCancel}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<Text variant="body">
|
||||
|
||||
@@ -71,6 +71,17 @@ export function MobileDrawer({
|
||||
<X width="1rem" height="1rem" />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
@@ -120,19 +131,6 @@ export function MobileDrawer({
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
{currentSessionId && (
|
||||
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</Drawer.Content>
|
||||
</Drawer.Portal>
|
||||
</Drawer.Root>
|
||||
|
||||
@@ -181,6 +181,14 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
|
||||
if (parts.length === 0) return;
|
||||
|
||||
// Merge consecutive assistant messages into a single UIMessage
|
||||
// to avoid split bubbles on page reload.
|
||||
const prevUI = uiMessages[uiMessages.length - 1];
|
||||
if (msg.role === "assistant" && prevUI && prevUI.role === "assistant") {
|
||||
prevUI.parts.push(...parts);
|
||||
return;
|
||||
}
|
||||
|
||||
uiMessages.push({
|
||||
id: `${sessionId}-${index}`,
|
||||
role: msg.role,
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
WarningDiamondIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
|
||||
import { ScaleLoader } from "../../components/ScaleLoader/ScaleLoader";
|
||||
|
||||
export type CreateAgentToolOutput =
|
||||
| AgentPreviewResponse
|
||||
@@ -134,7 +134,7 @@ export function ToolIcon({
|
||||
);
|
||||
}
|
||||
if (isStreaming) {
|
||||
return <OrbitLoader size={24} />;
|
||||
return <ScaleLoader size={14} />;
|
||||
}
|
||||
return <PlusIcon size={14} weight="regular" className="text-neutral-400" />;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user