mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
12 Commits
fix/copilo
...
feat/analy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9afd9fa01 | ||
|
|
ddb4f6e9de | ||
|
|
f585d97928 | ||
|
|
7d39234fdd | ||
|
|
6e9d4c4333 | ||
|
|
8aad333a45 | ||
|
|
856f0d980d | ||
|
|
3c3aadd361 | ||
|
|
e87a693fdd | ||
|
|
fe265c10d4 | ||
|
|
5d00a94693 | ||
|
|
6e1605994d |
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -27,12 +27,6 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -59,8 +53,6 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -126,8 +118,6 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -137,7 +127,6 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
@@ -196,28 +185,6 @@ async def list_sessions(
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
if sessions:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
pipe.hget(
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
processing_set = {
|
||||
session.session_id
|
||||
for session, st in zip(sessions, statuses)
|
||||
if st == "running"
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
@@ -225,7 +192,6 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
@@ -397,10 +363,6 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -408,26 +370,6 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage")
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -528,17 +470,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based)
|
||||
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
|
||||
try:
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
@@ -897,36 +828,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts."""
|
||||
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts for the authenticated user.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns an empty list if no custom prompts
|
||||
are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None:
|
||||
return SuggestedPromptsResponse(prompts=[])
|
||||
|
||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -250,130 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool = pydantic.Field(
|
||||
description="Indicates whether the same user owns the corresponding graph"
|
||||
)
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -55,7 +55,6 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -178,16 +179,10 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: OnboardingStep | None
|
||||
|
||||
|
||||
class CopilotCompletionPayload(NotificationPayload):
|
||||
session_id: str
|
||||
status: Literal["completed", "failed"]
|
||||
|
||||
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -312,11 +311,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
def github_repo_path(repo_url: str) -> str:
|
||||
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||
return repo_url.replace("https://github.com/", "")
|
||||
@@ -1,374 +0,0 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListCommitsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to list commits from",
|
||||
default="main",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of commits to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class CommitItem(TypedDict):
|
||||
sha: str
|
||||
message: str
|
||||
author: str
|
||||
date: str
|
||||
url: str
|
||||
|
||||
commit: CommitItem = SchemaField(
|
||||
title="Commit", description="A commit with its details"
|
||||
)
|
||||
commits: list[CommitItem] = SchemaField(
|
||||
description="List of commits with their details"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing commits failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||
description="This block lists commits on a branch in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommitsBlock.Input,
|
||||
output_schema=GithubListCommitsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"commits",
|
||||
[
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"commit",
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_commits": lambda *args, **kwargs: [
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_commits(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
per_page: int,
|
||||
page: int,
|
||||
) -> list[Output.CommitItem]:
|
||||
api = get_api(credentials)
|
||||
commits_url = repo_url + "/commits"
|
||||
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||
response = await api.get(commits_url, params=params)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
return [
|
||||
GithubListCommitsBlock.Output.CommitItem(
|
||||
sha=c["sha"],
|
||||
message=c["commit"]["message"],
|
||||
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||
)
|
||||
for c in data
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
commits = await self.list_commits(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "commits", commits
|
||||
for commit in commits:
|
||||
yield "commit", commit
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class FileOperation(StrEnum):
|
||||
"""File operations for GithubMultiFileCommitBlock.
|
||||
|
||||
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||
between creation and update — the blob is placed at the given path regardless
|
||||
of whether a file already exists there).
|
||||
|
||||
DELETE removes a file from the tree.
|
||||
"""
|
||||
|
||||
UPSERT = "upsert"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
class GithubMultiFileCommitBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to commit to",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Commit message",
|
||||
placeholder="Add new feature",
|
||||
)
|
||||
files: list[FileOperationInput] = SchemaField(
|
||||
description=(
|
||||
"List of file operations. Each item has: "
|
||||
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||
"'operation' (upsert/delete)"
|
||||
),
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the new commit")
|
||||
url: str = SchemaField(description="URL of the new commit")
|
||||
error: str = SchemaField(description="Error message if the commit failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||
description=(
|
||||
"This block creates a single commit with multiple file "
|
||||
"upsert/delete operations using the Git Trees API."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMultiFileCommitBlock.Input,
|
||||
output_schema=GithubMultiFileCommitBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "Add files",
|
||||
"files": [
|
||||
{
|
||||
"path": "src/new.py",
|
||||
"content": "print('hello')",
|
||||
"operation": "upsert",
|
||||
},
|
||||
{
|
||||
"path": "src/old.py",
|
||||
"content": "",
|
||||
"operation": "delete",
|
||||
},
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "newcommitsha"),
|
||||
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||
],
|
||||
test_mock={
|
||||
"multi_file_commit": lambda *args, **kwargs: (
|
||||
"newcommitsha",
|
||||
"https://github.com/owner/repo/commit/newcommitsha",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def multi_file_commit(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
files: list[FileOperationInput],
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
safe_branch = quote(branch, safe="")
|
||||
|
||||
# 1. Get the latest commit SHA for the branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||
response = await api.get(ref_url)
|
||||
ref_data = response.json()
|
||||
latest_commit_sha = ref_data["object"]["sha"]
|
||||
|
||||
# 2. Get the tree SHA of the latest commit
|
||||
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||
response = await api.get(commit_url)
|
||||
commit_data = response.json()
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
tree_entries: list[dict] = []
|
||||
upsert_files = []
|
||||
for file_op in files:
|
||||
path = file_op["path"]
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
|
||||
if operation == FileOperation.DELETE:
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": None, # null SHA = delete
|
||||
}
|
||||
)
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
|
||||
# 4. Create a new tree
|
||||
tree_url = repo_url + "/git/trees"
|
||||
tree_response = await api.post(
|
||||
tree_url,
|
||||
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||
)
|
||||
new_tree_sha = tree_response.json()["sha"]
|
||||
|
||||
# 5. Create a new commit
|
||||
new_commit_url = repo_url + "/git/commits"
|
||||
commit_response = await api.post(
|
||||
new_commit_url,
|
||||
json={
|
||||
"message": commit_message,
|
||||
"tree": new_tree_sha,
|
||||
"parents": [latest_commit_sha],
|
||||
},
|
||||
)
|
||||
new_commit_sha = commit_response.json()["sha"]
|
||||
|
||||
# 6. Update the branch reference
|
||||
try:
|
||||
await api.patch(
|
||||
ref_url,
|
||||
json={"sha": new_commit_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Commit {new_commit_sha} was created but failed to update "
|
||||
f"ref heads/{branch}: {e}. "
|
||||
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||
) from e
|
||||
|
||||
repo_path = github_repo_path(repo_url)
|
||||
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||
return new_commit_sha, commit_web_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -21,8 +20,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
class GithubMergePullRequestBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
merge_method: MergeMethod = SchemaField(
|
||||
description="Merge method to use: merge, squash, or rebase",
|
||||
default="merge",
|
||||
)
|
||||
commit_title: str = SchemaField(
|
||||
description="Title for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the merge commit")
|
||||
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||
message: str = SchemaField(description="Merge status message")
|
||||
error: str = SchemaField(description="Error message if the merge failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||
description="This block merges a pull request using merge, squash, or rebase.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMergePullRequestBlock.Input,
|
||||
output_schema=GithubMergePullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "abc123"),
|
||||
("merged", True),
|
||||
("message", "Pull Request successfully merged"),
|
||||
],
|
||||
test_mock={
|
||||
"merge_pr": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
True,
|
||||
"Pull Request successfully merged",
|
||||
)
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def merge_pr(
|
||||
credentials: GithubCredentials,
|
||||
pr_url: str,
|
||||
merge_method: MergeMethod,
|
||||
commit_title: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, bool, str]:
|
||||
api = get_api(credentials)
|
||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||
data: dict[str, str] = {"merge_method": merge_method}
|
||||
if commit_title:
|
||||
data["commit_title"] = commit_title
|
||||
if commit_message:
|
||||
data["commit_message"] = commit_message
|
||||
response = await api.put(merge_url, json=data)
|
||||
result = response.json()
|
||||
return result["sha"], result["merged"], result["message"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, merged, message = await self.merge_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.merge_method,
|
||||
input_data.commit_title,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "merged", merged
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||
# Pattern to capture the base repository URL and the pull request number
|
||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
match = re.match(pattern, pr_url)
|
||||
if not match:
|
||||
return pr_url
|
||||
|
||||
scheme, base_url, pr_number = match.groups()
|
||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||
base_url, pr_number = match.groups()
|
||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -17,7 +19,6 @@ from ._auth import (
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
|
||||
tags_url = repo_url + "/tags"
|
||||
response = await api.get(tags_url)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
|
||||
yield "tag", tag
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(branches_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
|
||||
yield "release", release
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||
description="This block lists all users who have starred a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListStargazersBlock.Input,
|
||||
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
|
||||
class GithubGetRepositoryInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
name: str = SchemaField(description="Repository name")
|
||||
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||
description: str = SchemaField(description="Repository description")
|
||||
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||
private: bool = SchemaField(description="Whether the repository is private")
|
||||
html_url: str = SchemaField(description="Web URL of the repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL")
|
||||
stars: int = SchemaField(description="Number of stars")
|
||||
forks: int = SchemaField(description="Number of forks")
|
||||
open_issues: int = SchemaField(description="Number of open issues")
|
||||
error: str = SchemaField(
|
||||
description="Error message if fetching repo info failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||
description="This block retrieves metadata about a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("name", "repo"),
|
||||
("full_name", "owner/repo"),
|
||||
("description", "A test repo"),
|
||||
("default_branch", "main"),
|
||||
("private", False),
|
||||
("html_url", "https://github.com/owner/repo"),
|
||||
("clone_url", "https://github.com/owner/repo.git"),
|
||||
("stars", 42),
|
||||
("forks", 5),
|
||||
("open_issues", 3),
|
||||
],
|
||||
test_mock={
|
||||
"get_repo_info": lambda *args, **kwargs: {
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"description": "A test repo",
|
||||
"default_branch": "main",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/owner/repo",
|
||||
"clone_url": "https://github.com/owner/repo.git",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 5,
|
||||
"open_issues_count": 3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||
api = get_api(credentials)
|
||||
response = await api.get(repo_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||
yield "name", data["name"]
|
||||
yield "full_name", data["full_name"]
|
||||
yield "description", data.get("description", "") or ""
|
||||
yield "default_branch", data["default_branch"]
|
||||
yield "private", data["private"]
|
||||
yield "html_url", data["html_url"]
|
||||
yield "clone_url", data["clone_url"]
|
||||
yield "stars", data["stargazers_count"]
|
||||
yield "forks", data["forks_count"]
|
||||
yield "open_issues", data["open_issues_count"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubForkRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to fork",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
organization: str = SchemaField(
|
||||
description="Organization to fork into (leave empty to fork to your account)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the forked repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||
error: str = SchemaField(description="Error message if the fork failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||
description="This block forks a GitHub repository to your account or an organization.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubForkRepositoryBlock.Input,
|
||||
output_schema=GithubForkRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"organization": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/myuser/repo"),
|
||||
("clone_url", "https://github.com/myuser/repo.git"),
|
||||
("full_name", "myuser/repo"),
|
||||
],
|
||||
test_mock={
|
||||
"fork_repo": lambda *args, **kwargs: (
|
||||
"https://github.com/myuser/repo",
|
||||
"https://github.com/myuser/repo.git",
|
||||
"myuser/repo",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def fork_repo(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
organization: str,
|
||||
) -> tuple[str, str, str]:
|
||||
api = get_api(credentials)
|
||||
forks_url = repo_url + "/forks"
|
||||
data: dict[str, str] = {}
|
||||
if organization:
|
||||
data["organization"] = organization
|
||||
response = await api.post(forks_url, json=data)
|
||||
result = response.json()
|
||||
return result["html_url"], result["clone_url"], result["full_name"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, clone_url, full_name = await self.fork_repo(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.organization,
|
||||
)
|
||||
yield "url", url
|
||||
yield "clone_url", clone_url
|
||||
yield "full_name", full_name
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubStarRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to star",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the star operation")
|
||||
error: str = SchemaField(description="Error message if starring failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||
description="This block stars a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubStarRepositoryBlock.Input,
|
||||
output_schema=GithubStarRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Repository starred successfully")],
|
||||
test_mock={
|
||||
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
await api.put(
|
||||
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||
headers={"Content-Length": "0"},
|
||||
)
|
||||
return "Repository starred successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.star_repo(credentials, input_data.repo_url)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of branches to return per page (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(
|
||||
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||
)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCompareBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="Base branch or commit SHA",
|
||||
placeholder="main",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="Head branch or commit SHA to compare against base",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class FileChange(TypedDict):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
patch: str
|
||||
|
||||
status: str = SchemaField(
|
||||
description="Comparison status: ahead, behind, diverged, or identical"
|
||||
)
|
||||
ahead_by: int = SchemaField(
|
||||
description="Number of commits head is ahead of base"
|
||||
)
|
||||
behind_by: int = SchemaField(
|
||||
description="Number of commits head is behind base"
|
||||
)
|
||||
total_commits: int = SchemaField(
|
||||
description="Total number of commits in the comparison"
|
||||
)
|
||||
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||
file: FileChange = SchemaField(
|
||||
title="Changed File", description="A changed file with its diff"
|
||||
)
|
||||
files: list[FileChange] = SchemaField(
|
||||
description="List of changed files with their diffs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comparison failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||
description="This block compares two branches or commits in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCompareBranchesBlock.Input,
|
||||
output_schema=GithubCompareBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"base": "main",
|
||||
"head": "feature",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "ahead"),
|
||||
("ahead_by", 2),
|
||||
("behind_by", 0),
|
||||
("total_commits", 2),
|
||||
("diff", "+++ b/file.py\n+new line"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"compare_branches": lambda *args, **kwargs: {
|
||||
"status": "ahead",
|
||||
"ahead_by": 2,
|
||||
"behind_by": 0,
|
||||
"total_commits": 2,
|
||||
"files": [
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def compare_branches(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
base: str,
|
||||
head: str,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
safe_base = quote(base, safe="")
|
||||
safe_head = quote(head, safe="")
|
||||
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||
response = await api.get(compare_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.compare_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.base,
|
||||
input_data.head,
|
||||
)
|
||||
yield "status", data["status"]
|
||||
yield "ahead_by", data["ahead_by"]
|
||||
yield "behind_by", data["behind_by"]
|
||||
yield "total_commits", data["total_commits"]
|
||||
|
||||
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||
GithubCompareBranchesBlock.Output.FileChange(
|
||||
filename=f["filename"],
|
||||
status=f["status"],
|
||||
additions=f["additions"],
|
||||
deletions=f["deletions"],
|
||||
patch=f.get("patch", ""),
|
||||
)
|
||||
for f in data.get("files", [])
|
||||
]
|
||||
|
||||
# Build unified diff
|
||||
diff_parts = []
|
||||
for f in data.get("files", []):
|
||||
patch = f.get("patch", "")
|
||||
if patch:
|
||||
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||
yield "diff", "\n".join(diff_parts)
|
||||
|
||||
yield "files", files
|
||||
for file in files:
|
||||
yield "file", file
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,720 +0,0 @@
|
||||
import base64
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if reading the file failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to main)",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubSearchCodeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
query: str = SchemaField(
|
||||
description="Search query (GitHub code search syntax)",
|
||||
placeholder="className language:python",
|
||||
)
|
||||
repo: str = SchemaField(
|
||||
description="Restrict search to a repository (owner/repo format, optional)",
|
||||
default="",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class SearchResult(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
repository: str
|
||||
url: str
|
||||
score: float
|
||||
|
||||
result: SearchResult = SchemaField(
|
||||
title="Result", description="A code search result"
|
||||
)
|
||||
results: list[SearchResult] = SchemaField(
|
||||
description="List of code search results"
|
||||
)
|
||||
total_count: int = SchemaField(description="Total number of matching results")
|
||||
error: str = SchemaField(description="Error message if search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||
description="This block searches for code in GitHub repositories.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSearchCodeBlock.Input,
|
||||
output_schema=GithubSearchCodeBlock.Output,
|
||||
test_input={
|
||||
"query": "addClass",
|
||||
"repo": "owner/repo",
|
||||
"per_page": 30,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("total_count", 1),
|
||||
(
|
||||
"results",
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_code": lambda *args, **kwargs: (
|
||||
1,
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_code(
|
||||
credentials: GithubCredentials,
|
||||
query: str,
|
||||
repo: str,
|
||||
per_page: int,
|
||||
) -> tuple[int, list[Output.SearchResult]]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
full_query = f"{query} repo:{repo}" if repo else query
|
||||
params = {"q": full_query, "per_page": str(per_page)}
|
||||
response = await api.get("https://api.github.com/search/code", params=params)
|
||||
data = response.json()
|
||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||
GithubSearchCodeBlock.Output.SearchResult(
|
||||
name=item["name"],
|
||||
path=item["path"],
|
||||
repository=item["repository"]["full_name"],
|
||||
url=item["html_url"],
|
||||
score=item["score"],
|
||||
)
|
||||
for item in data["items"]
|
||||
]
|
||||
return data["total_count"], results
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total_count, results = await self.search_code(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.repo,
|
||||
input_data.per_page,
|
||||
)
|
||||
yield "total_count", total_count
|
||||
yield "results", results
|
||||
for result in results:
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetRepositoryTreeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to get the tree from",
|
||||
default="main",
|
||||
)
|
||||
recursive: bool = SchemaField(
|
||||
description="Whether to recursively list the entire tree",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class TreeEntry(TypedDict):
|
||||
path: str
|
||||
type: str
|
||||
size: int
|
||||
sha: str
|
||||
|
||||
entry: TreeEntry = SchemaField(
|
||||
title="Tree Entry", description="A file or directory in the tree"
|
||||
)
|
||||
entries: list[TreeEntry] = SchemaField(
|
||||
description="List of all files and directories in the tree"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the tree was truncated due to size"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting tree failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"recursive": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("truncated", False),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"entry",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tree": lambda *args, **kwargs: (
|
||||
False,
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_tree(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
recursive: bool,
|
||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||
api = get_api(credentials)
|
||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||
params = {"recursive": "1"} if recursive else {}
|
||||
response = await api.get(tree_url, params=params)
|
||||
data = response.json()
|
||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||
path=item["path"],
|
||||
type=item["type"],
|
||||
size=item.get("size", 0),
|
||||
sha=item["sha"],
|
||||
)
|
||||
for item in data["tree"]
|
||||
]
|
||||
return data.get("truncated", False), entries
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
truncated, entries = await self.get_tree(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.recursive,
|
||||
)
|
||||
yield "truncated", truncated
|
||||
yield "entries", entries
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,120 +0,0 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||
from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
|
||||
|
||||
class TestPreparePrApiUrl:
|
||||
def test_https_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||
|
||||
def test_http_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||
|
||||
def test_no_scheme_defaults_to_https(self):
|
||||
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||
|
||||
def test_reviewers_path(self):
|
||||
result = prepare_pr_api_url(
|
||||
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||
)
|
||||
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||
|
||||
def test_invalid_url_returned_as_is(self):
|
||||
url = "https://example.com/not-a-pr"
|
||||
assert prepare_pr_api_url(url, "merge") == url
|
||||
|
||||
def test_empty_string(self):
|
||||
assert prepare_pr_api_url("", "merge") == ""
|
||||
|
||||
|
||||
# ── Error-path block tests ──
|
||||
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||
# which returns early on empty test_output).
|
||||
|
||||
|
||||
def _mock_block(block, mocks: dict):
|
||||
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||
for name, mock_fn in mocks.items():
|
||||
original = getattr(block, name)
|
||||
if inspect.iscoroutinefunction(original):
|
||||
|
||||
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||
return _fn(*args, **kwargs)
|
||||
|
||||
setattr(block, name, async_mock)
|
||||
else:
|
||||
setattr(block, name, mock_fn)
|
||||
|
||||
|
||||
def _raise(exc: Exception):
|
||||
"""Helper that returns a callable which raises the given exception."""
|
||||
|
||||
def _raiser(*args, **kwargs):
|
||||
raise exc
|
||||
|
||||
return _raiser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_error_path():
|
||||
block = GithubMergePullRequestBlock()
|
||||
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||
input_data = {
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_file_commit_error_path():
|
||||
block = GithubMultiFileCommitBlock()
|
||||
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "test",
|
||||
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
# ── FileOperation enum tests ──
|
||||
|
||||
|
||||
class TestFileOperation:
|
||||
def test_upsert_value(self):
|
||||
assert FileOperation.UPSERT == "upsert"
|
||||
|
||||
def test_delete_value(self):
|
||||
assert FileOperation.DELETE == "delete"
|
||||
|
||||
def test_invalid_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("create")
|
||||
|
||||
def test_invalid_value_raises_typo(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("upser")
|
||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except Exception:
|
||||
# Keep extraction resilient if html2text is unavailable or fails.
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
|
||||
@@ -140,31 +140,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||
CODESTRAL = "mistralai/codestral-2508"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -172,11 +160,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_3 = "x-ai/grok-3"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
@@ -354,41 +340,17 @@ MODEL_METADATA = {
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Pro Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3 Flash Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
@@ -396,15 +358,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Flash Lite Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
@@ -426,78 +379,12 @@ MODEL_METADATA = {
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
None,
|
||||
"Mistral Large 3 2512",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
None,
|
||||
"Mistral Medium 3.1",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Mistral Small 3.2 24B",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.CODESTRAL: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
None,
|
||||
"Codestral 2508",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Translate 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
32768,
|
||||
"Command A Reasoning 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
3,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Vision 07.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
@@ -510,15 +397,6 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8000,
|
||||
"Sonar Reasoning Pro",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
2,
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
@@ -564,9 +442,6 @@ MODEL_METADATA = {
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
@@ -576,15 +451,6 @@ MODEL_METADATA = {
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_3: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Grok 3",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
2,
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr, field_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -13,7 +13,6 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -36,20 +35,6 @@ class PerplexityModel(str, Enum):
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||
if isinstance(value, PerplexityModel):
|
||||
return value
|
||||
try:
|
||||
return PerplexityModel(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PerplexityModel '{value}', "
|
||||
f"falling back to {PerplexityModel.SONAR.value}"
|
||||
)
|
||||
return PerplexityModel.SONAR
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
@@ -88,25 +73,6 @@ class PerplexityBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||
"""Fall back to SONAR if the model value is not a valid
|
||||
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||
generator)."""
|
||||
return _sanitize_perplexity_model(v)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""Sanitize the model field before JSON schema validation so that
|
||||
invalid values are replaced with the default instead of raising a
|
||||
BlockInputError."""
|
||||
model_value = data.get("model")
|
||||
if model_value is not None:
|
||||
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||
return super().validate_data(data)
|
||||
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
|
||||
@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
|
||||
("post_id", "abc123"),
|
||||
],
|
||||
test_mock={"delete_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
|
||||
("comment_id", "xyz789"),
|
||||
],
|
||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.perplexity import (
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
PerplexityBlock,
|
||||
PerplexityModel,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**overrides) -> dict:
|
||||
defaults = {
|
||||
"prompt": "test query",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return defaults
|
||||
|
||||
|
||||
class TestPerplexityModelFallback:
|
||||
"""Tests for fallback_invalid_model field_validator."""
|
||||
|
||||
def test_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_valid_model_string_is_kept(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||
assert inp.model == PerplexityModel.SONAR_PRO
|
||||
|
||||
def test_valid_enum_value_is_kept(self):
|
||||
inp = PerplexityBlock.Input(
|
||||
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||
)
|
||||
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||
|
||||
def test_default_model_when_omitted(self):
|
||||
inp = PerplexityBlock.Input(**_make_input())
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_value",
|
||||
[
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
)
|
||||
def test_all_valid_models_accepted(self, model_value: str):
|
||||
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||
assert inp.model.value == model_value
|
||||
|
||||
|
||||
class TestPerplexityValidateData:
|
||||
"""Tests for validate_data which runs during block execution (before
|
||||
Pydantic instantiation). Invalid models must be sanitized here so
|
||||
JSON schema validation does not reject them."""
|
||||
|
||||
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == PerplexityModel.SONAR.value
|
||||
|
||||
def test_valid_model_unchanged_by_validate_data(self):
|
||||
data = _make_input(model="perplexity/sonar-pro")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_missing_model_uses_default(self):
|
||||
data = _make_input() # no model key
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
inp = PerplexityBlock.Input(**data)
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
@@ -18,13 +18,11 @@ from langfuse import propagate_attributes
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.rate_limit import record_token_usage
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -38,7 +36,6 @@ from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
@@ -49,11 +46,7 @@ from backend.copilot.service import (
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -228,9 +221,6 @@ async def stream_chat_completion_baseline(
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
@@ -242,7 +232,6 @@ async def stream_chat_completion_baseline(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
@@ -253,18 +242,7 @@ async def stream_chat_completion_baseline(
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
@@ -433,53 +411,6 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Fallback: estimate tokens via tiktoken when the provider does
|
||||
# not honour stream_options={"include_usage": True}.
|
||||
# Count the full message list (system + history + turn) since
|
||||
# each API call sends the complete context window.
|
||||
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
|
||||
turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 0
|
||||
)
|
||||
turn_completion_tokens = max(
|
||||
estimate_token_count_str(assistant_text, model=config.model), 0
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Emit token usage and update session for persistence
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
total_tokens = turn_prompt_tokens + turn_completion_tokens
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
total_tokens,
|
||||
)
|
||||
# Record for rate limiting counters
|
||||
if user_id:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(
|
||||
"[Baseline] Failed to record token usage: %s", usage_err
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
@@ -490,16 +421,4 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
promptTokens=turn_prompt_tokens,
|
||||
completionTokens=turn_completion_tokens,
|
||||
totalTokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -70,20 +70,6 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
|
||||
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
|
||||
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
|
||||
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
@@ -129,7 +115,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
|
||||
@@ -73,9 +73,6 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
|
||||
@@ -52,11 +52,6 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Type coercion**: The platform automatically coerces expanded string values
|
||||
to match the block's expected input types. For example, if a block expects
|
||||
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
|
||||
an @@agptfile: expansion), the string will be parsed into the correct type.
|
||||
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
def __init__(self, window: str, resets_at: datetime):
|
||||
self.window = window
|
||||
self.resets_at = resets_at
|
||||
delta = resets_at - datetime.now(UTC)
|
||||
total_secs = delta.total_seconds()
|
||||
if total_secs <= 0:
|
||||
time_str = "now"
|
||||
else:
|
||||
hours = int(total_secs // 3600)
|
||||
minutes = int((total_secs % 3600) // 60)
|
||||
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||
super().__init__(
|
||||
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||
)
|
||||
|
||||
|
||||
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
|
||||
|
||||
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
year, week, _ = now.isocalendar()
|
||||
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||
|
||||
|
||||
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
|
||||
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
|
||||
|
||||
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
|
||||
pushes to *next* Monday so the current week's window stays open.
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||
days=days_until_monday
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
|
||||
"""Fetch daily and weekly token counters from Redis.
|
||||
|
||||
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
return int(daily_raw or 0), int(weekly_raw or 0)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for usage status, returning zeros", exc_info=True
|
||||
)
|
||||
daily_used, weekly_used = 0, 0
|
||||
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
daily_used, weekly_used = await _fetch_counters(user_id, now)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for rate limit check, allowing request", exc_info=True
|
||||
)
|
||||
return
|
||||
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
"""
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||
|
||||
await pipe.execute()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -1,334 +0,0 @@
|
||||
"""Unit tests for CoPilot rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitExceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
def test_message_contains_window_name(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert "daily" in str(exc)
|
||||
|
||||
def test_message_contains_reset_time(self):
|
||||
exc = RateLimitExceeded(
|
||||
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
)
|
||||
msg = str(exc)
|
||||
# Allow for slight timing drift (29m or 30m)
|
||||
assert "2h " in msg
|
||||
assert "Resets in" in msg
|
||||
|
||||
def test_message_minutes_only_when_under_one_hour(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||
msg = str(exc)
|
||||
assert "Resets in" in msg
|
||||
# Should not have "0h"
|
||||
assert "0h" not in msg
|
||||
|
||||
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||
assert "Resets in now" in str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_usage_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUsageStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
assert status.daily.used == 500
|
||||
assert status.daily.limit == 10000
|
||||
assert status.weekly.used == 2000
|
||||
assert status.weekly.limit == 50000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zeros_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_daily_counter(self):
|
||||
"""Daily counter is None (new day), weekly has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 3000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_weekly_counter(self):
|
||||
"""Weekly counter is None (start of week), daily has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
# Daily reset should be within 24h
|
||||
assert status.daily.resets_at > now
|
||||
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_rate_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRateLimit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_under_limit(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_daily_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_weekly_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_redis_unavailable(self):
|
||||
"""Fail-open: allow requests when Redis is down."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_check_when_limit_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_redis_counters(self):
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
|
||||
# Daily key TTL should be positive (seconds until next midnight)
|
||||
daily_ttl = expire_calls[0].args[1]
|
||||
assert daily_ttl >= 1
|
||||
|
||||
# Weekly key TTL should be positive (seconds until next Monday)
|
||||
weekly_ttl = expire_calls[1].args[1]
|
||||
assert weekly_ttl >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_failure_gracefully(self):
|
||||
"""Should not raise when Redis is unavailable."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
@@ -186,29 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(
|
||||
..., description="Total number of tokens (raw, not weighted)"
|
||||
)
|
||||
cacheReadTokens: int = Field(
|
||||
default=0, description="Prompt tokens served from cache (10% cost)"
|
||||
)
|
||||
cacheCreationTokens: int = Field(
|
||||
default=0, description="Prompt tokens written to cache (25% cost)"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
|
||||
@@ -198,7 +198,6 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@@ -1,546 +0,0 @@
|
||||
"""End-to-end compaction flow test.
|
||||
|
||||
Simulates the full service.py compaction lifecycle using real-format
|
||||
JSONL session files — no SDK subprocess needed. Exercises:
|
||||
|
||||
1. TranscriptBuilder loads a "downloaded" transcript
|
||||
2. User query appended, assistant response streamed
|
||||
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||
4. Next message → emit_start_if_ready() yields spinner events
|
||||
5. Message after that → emit_end_if_ready() returns end events
|
||||
6. _read_compacted_entries() reads the CLI session file
|
||||
7. TranscriptBuilder.replace_entries() syncs state
|
||||
8. More messages appended post-compaction
|
||||
9. to_jsonl() exports full state for upload
|
||||
10. Fresh builder loads the export — roundtrip verified
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import CompactionTracker
|
||||
from backend.copilot.sdk.transcript import strip_progress_entries
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
|
||||
"""Test-only: read compacted entries from a session JSONL file.
|
||||
|
||||
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
|
||||
entry onward, or ``None`` if no summary is found.
|
||||
"""
|
||||
content = Path(path).read_text()
|
||||
lines = content.strip().split("\n")
|
||||
compact_idx: int | None = None
|
||||
parsed: list[dict] = []
|
||||
raw_lines: list[str] = []
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
parsed.append(entry)
|
||||
raw_lines.append(line.strip())
|
||||
if compact_idx is None and entry.get("isCompactSummary"):
|
||||
compact_idx = len(parsed) - 1
|
||||
if compact_idx is None:
|
||||
return None
|
||||
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic CLI session file content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Pre-compaction conversation
|
||||
USER_1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "What files are in this project?"},
|
||||
}
|
||||
ASST_1_THINKING = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
ASST_1_TOOL = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"},
|
||||
}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
TOOL_RESULT_1 = {
|
||||
"type": "user",
|
||||
"uuid": "tr1",
|
||||
"parentUuid": "a1-tool",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tu1",
|
||||
"content": "file1.py\nfile2.py",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
ASST_1_TEXT = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-text",
|
||||
"parentUuid": "tr1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_bbb",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
# Progress entries (should be stripped during upload)
|
||||
PROGRESS_1 = {
|
||||
"type": "progress",
|
||||
"uuid": "prog1",
|
||||
"parentUuid": "a1-tool",
|
||||
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||
}
|
||||
# Second user message
|
||||
USER_2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1-text",
|
||||
"message": {"role": "user", "content": "Show me file1.py"},
|
||||
}
|
||||
ASST_2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ccc",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Compaction summary (written by CLI after context compaction) ---
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||
"User then asked to see file1.py."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction assistant response
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a3",
|
||||
"parentUuid": "cs1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ddd",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction user follow-up
|
||||
USER_3 = {
|
||||
"type": "user",
|
||||
"uuid": "u3",
|
||||
"parentUuid": "a3",
|
||||
"message": {"role": "user", "content": "Now show file2.py"},
|
||||
}
|
||||
ASST_3 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a4",
|
||||
"parentUuid": "u3",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_eee",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionE2E:
|
||||
def _write_session_file(self, session_dir, entries):
|
||||
"""Write a CLI session JSONL file."""
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(*entries))
|
||||
return path
|
||||
|
||||
def test_full_compaction_lifecycle(self, tmp_path):
|
||||
"""Simulate the complete service.py compaction flow.
|
||||
|
||||
Timeline:
|
||||
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||
2. Current turn: download → load_previous
|
||||
3. User sends "Now show file2.py" → append_user
|
||||
4. SDK starts streaming response
|
||||
5. Mid-stream: PreCompact hook fires (context too large)
|
||||
6. CLI writes compaction summary to session file
|
||||
7. Next SDK message → emit_start (spinner)
|
||||
8. Following message → emit_end (end events)
|
||||
9. _read_compacted_entries reads the session file
|
||||
10. replace_entries syncs TranscriptBuilder
|
||||
11. More assistant messages appended
|
||||
12. Export → upload → next turn downloads it
|
||||
"""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||
previous_transcript = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
# --- Step 3: User sends new query ---
|
||||
builder.append_user("Now show file2.py")
|
||||
assert builder.entry_count == 8
|
||||
|
||||
# --- Step 4: SDK starts streaming ---
|
||||
builder.append_assistant(
|
||||
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
assert builder.entry_count == 9
|
||||
|
||||
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
USER_3,
|
||||
ASST_3,
|
||||
],
|
||||
)
|
||||
|
||||
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test-user")
|
||||
# on_compact is a property returning Event.set callable
|
||||
tracker.on_compact()
|
||||
|
||||
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||
start_events = tracker.emit_start_if_ready()
|
||||
assert len(start_events) == 3
|
||||
assert isinstance(start_events[0], StreamStartStep)
|
||||
assert isinstance(start_events[1], StreamToolInputStart)
|
||||
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||
|
||||
# Verify tool_call_id is set
|
||||
tool_call_id = start_events[1].toolCallId
|
||||
assert tool_call_id.startswith("compaction-")
|
||||
|
||||
# --- Step 9: Following message → emit_end ---
|
||||
end_events = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events) == 2
|
||||
assert isinstance(end_events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(end_events[1], StreamFinishStep)
|
||||
# Verify same tool_call_id
|
||||
assert end_events[0].toolCallId == tool_call_id
|
||||
|
||||
# Session should have compaction messages persisted
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[1].role == "tool"
|
||||
|
||||
# --- Step 10: _read_compacted_entries + replace_entries ---
|
||||
result = _read_compacted_entries(str(session_file))
|
||||
assert result is not None
|
||||
compacted_dicts, compacted_jsonl = result
|
||||
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||
assert len(compacted_dicts) == 4
|
||||
assert compacted_dicts[0]["uuid"] == "cs1"
|
||||
assert compacted_dicts[0]["isCompactSummary"] is True
|
||||
|
||||
# Replace builder state with compacted JSONL
|
||||
old_count = builder.entry_count
|
||||
builder.replace_entries(compacted_jsonl)
|
||||
assert builder.entry_count == 4 # Only compacted entries
|
||||
assert builder.entry_count < old_count # Compaction reduced entries
|
||||
|
||||
# --- Step 11: More assistant messages after compaction ---
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
assert builder.entry_count == 5
|
||||
|
||||
# --- Step 12: Export for upload ---
|
||||
output = builder.to_jsonl()
|
||||
assert output # Not empty
|
||||
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(output_entries) == 5
|
||||
|
||||
# Verify structure:
|
||||
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||
assert output_entries[0]["type"] == "summary"
|
||||
assert output_entries[0].get("isCompactSummary") is True
|
||||
assert output_entries[0]["uuid"] == "cs1"
|
||||
assert output_entries[1]["uuid"] == "a3"
|
||||
assert output_entries[2]["uuid"] == "u3"
|
||||
assert output_entries[3]["uuid"] == "a4"
|
||||
assert output_entries[4]["type"] == "assistant"
|
||||
|
||||
# Verify parent chain is intact
|
||||
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||
|
||||
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 5
|
||||
|
||||
# isCompactSummary survives roundtrip
|
||||
output2 = builder2.to_jsonl()
|
||||
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||
assert first_entry.get("isCompactSummary") is True
|
||||
|
||||
# Can append more messages
|
||||
builder2.append_user("What about file3.py?")
|
||||
assert builder2.entry_count == 6
|
||||
final_output = builder2.to_jsonl()
|
||||
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||
assert last_entry["type"] == "user"
|
||||
# Parented to the last entry from previous turn
|
||||
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||
|
||||
def test_double_compaction_within_session(self, tmp_path):
|
||||
"""Two compactions in the same session (across reset_for_query)."""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test")
|
||||
builder = TranscriptBuilder()
|
||||
|
||||
# --- First query with compaction ---
|
||||
builder.append_user("first question")
|
||||
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||
|
||||
# Write session file for first compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-first",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "First compaction summary"},
|
||||
}
|
||||
first_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-first",
|
||||
"parentUuid": "cs-first",
|
||||
"message": {"role": "assistant", "content": "first post-compact"},
|
||||
}
|
||||
file1 = session_dir / "session1.jsonl"
|
||||
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
end_events1 = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events1) == 2 # output + finish
|
||||
|
||||
result1_entries = _read_compacted_entries(str(file1))
|
||||
assert result1_entries is not None
|
||||
_, compacted1_jsonl = result1_entries
|
||||
builder.replace_entries(compacted1_jsonl)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# --- Reset for second query ---
|
||||
tracker.reset_for_query()
|
||||
|
||||
# --- Second query with compaction ---
|
||||
builder.append_user("second question")
|
||||
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-second",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Second compaction summary"},
|
||||
}
|
||||
second_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-second",
|
||||
"parentUuid": "cs-second",
|
||||
"message": {"role": "assistant", "content": "second post-compact"},
|
||||
}
|
||||
file2 = session_dir / "session2.jsonl"
|
||||
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
end_events2 = _run(tracker.emit_end_if_ready(session))
|
||||
assert len(end_events2) == 2 # output + finish
|
||||
|
||||
result2_entries = _read_compacted_entries(str(file2))
|
||||
assert result2_entries is not None
|
||||
_, compacted2_jsonl = result2_entries
|
||||
builder.replace_entries(compacted2_jsonl)
|
||||
assert builder.entry_count == 2 # Only second compaction entries
|
||||
|
||||
# Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs-second"
|
||||
assert entries[0].get("isCompactSummary") is True
|
||||
|
||||
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
|
||||
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||
|
||||
This tests the exact sequence that happens across two turns:
|
||||
Turn 1: SDK produces transcript with progress entries
|
||||
Upload: strip_progress_entries removes progress, upload to cloud
|
||||
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||
"""
|
||||
session_dir = tmp_path / "session"
|
||||
session_dir.mkdir(parents=True)
|
||||
|
||||
# --- Turn 1: SDK produces raw transcript ---
|
||||
raw_content = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
|
||||
# Strip progress for upload
|
||||
stripped = strip_progress_entries(raw_content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
# Progress should be gone
|
||||
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||
|
||||
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(stripped)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
builder.append_user("Now show file2.py")
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Reading file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
|
||||
# CLI writes session file with compaction
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
],
|
||||
)
|
||||
|
||||
result = _read_compacted_entries(str(session_file))
|
||||
assert result is not None
|
||||
_, compacted_jsonl = result
|
||||
builder.replace_entries(compacted_jsonl)
|
||||
|
||||
# Append post-compaction message
|
||||
builder.append_user("Thanks!")
|
||||
output = builder.to_jsonl()
|
||||
|
||||
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||
builder3 = TranscriptBuilder()
|
||||
builder3.load_previous(output)
|
||||
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||
assert builder3.entry_count == 3
|
||||
|
||||
# Compact summary survived the full pipeline
|
||||
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
assert first["type"] == "summary"
|
||||
@@ -26,17 +26,3 @@ For other services, search the MCP registry at https://registry.modelcontextprot
|
||||
|
||||
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||
|
||||
### Communication style
|
||||
|
||||
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
|
||||
Use plain, friendly language instead:
|
||||
|
||||
| Instead of… | Say… |
|
||||
|---|---|
|
||||
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
|
||||
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
|
||||
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
|
||||
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
|
||||
|
||||
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".
|
||||
|
||||
@@ -221,12 +221,12 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
|
||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
@@ -111,9 +111,7 @@ def _validate_user_isolation(
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(
|
||||
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||
)
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
@@ -171,7 +169,7 @@ def create_security_hooks(
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
@@ -213,7 +211,7 @@ def create_security_hooks(
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
|
||||
@@ -40,13 +40,11 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..prompting import get_sdk_supplement
|
||||
from ..rate_limit import record_token_usage
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -56,7 +54,6 @@ from ..response_model import (
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
@@ -78,12 +75,8 @@ from .tool_adapter import (
|
||||
wait_for_stash,
|
||||
)
|
||||
from .transcript import (
|
||||
COMPACT_THRESHOLD_BYTES,
|
||||
TranscriptDownload,
|
||||
cleanup_cli_project_dir,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
read_cli_session_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -301,7 +294,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
return
|
||||
|
||||
# Clean the CLI's project directory (transcripts + tool-results).
|
||||
@@ -395,7 +388,7 @@ async def _compress_messages(
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[SDK] Context compression with LLM failed: %s", e)
|
||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
||||
# Fall back to truncation-only (no LLM summarization)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
@@ -631,56 +624,6 @@ async def _prepare_file_attachments(
|
||||
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
|
||||
|
||||
|
||||
async def _maybe_compact_and_upload(
|
||||
dl: TranscriptDownload,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str:
|
||||
"""Compact an oversized transcript and upload the compacted version.
|
||||
|
||||
Returns the (possibly compacted) transcript content, or an empty string
|
||||
if compaction was needed but failed.
|
||||
"""
|
||||
content = dl.content
|
||||
if len(content) <= COMPACT_THRESHOLD_BYTES:
|
||||
return content
|
||||
|
||||
logger.warning(
|
||||
"%s Transcript oversized (%dB > %dB), compacting",
|
||||
log_prefix,
|
||||
len(content),
|
||||
COMPACT_THRESHOLD_BYTES,
|
||||
)
|
||||
compacted = await compact_transcript(content, log_prefix=log_prefix)
|
||||
if not compacted:
|
||||
logger.warning(
|
||||
"%s Compaction failed, skipping resume for this turn", log_prefix
|
||||
)
|
||||
return ""
|
||||
|
||||
# Keep the original message_count: it reflects the number of
|
||||
# session.messages covered by this transcript, which the gap-fill
|
||||
# logic uses as a slice index. Counting JSONL lines would give a
|
||||
# smaller number (compacted messages != session message count) and
|
||||
# cause already-covered messages to be re-injected.
|
||||
try:
|
||||
await upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=compacted,
|
||||
message_count=dl.message_count,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s Failed to upload compacted transcript",
|
||||
log_prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
return compacted
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -792,14 +735,6 @@ async def stream_chat_completion_sdk(
|
||||
_otel_ctx: Any = None
|
||||
|
||||
# Make sure there is no more code between the lock acquisition and try-block.
|
||||
# Token usage accumulators — populated from ResultMessage at end of turn
|
||||
turn_prompt_tokens = 0 # uncached input tokens only
|
||||
turn_completion_tokens = 0
|
||||
turn_cache_read_tokens = 0
|
||||
turn_cache_creation_tokens = 0
|
||||
total_tokens = 0 # computed once before StreamUsage, reused in finally
|
||||
turn_cost_usd: float | None = None
|
||||
|
||||
try:
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support).
|
||||
# Pre-compute the cwd here so the exact working directory path can be
|
||||
@@ -892,33 +827,20 @@ async def stream_chat_completion_sdk(
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
transcript_content = await _maybe_compact_and_upload(
|
||||
dl,
|
||||
user_id=user_id or "",
|
||||
session_id=session_id,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
# Load previous context into builder (empty string is a no-op)
|
||||
if transcript_content:
|
||||
transcript_builder.load_previous(
|
||||
transcript_content, log_prefix=log_prefix
|
||||
)
|
||||
resume_file = (
|
||||
write_transcript_to_tempfile(
|
||||
transcript_content, session_id, sdk_cwd
|
||||
)
|
||||
if transcript_content
|
||||
else None
|
||||
# Load previous FULL context into builder
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
|
||||
f"{log_prefix} Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
else:
|
||||
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"{log_prefix} No transcript available "
|
||||
@@ -1188,7 +1110,7 @@ async def stream_chat_completion_sdk(
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log ResultMessage details and capture token usage
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"%s Received: ResultMessage %s "
|
||||
@@ -1207,46 +1129,9 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
# Capture token usage from ResultMessage.
|
||||
# Anthropic reports cached tokens separately:
|
||||
# input_tokens = uncached only
|
||||
# cache_read_input_tokens = served from cache
|
||||
# cache_creation_input_tokens = written to cache
|
||||
if sdk_msg.usage:
|
||||
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
|
||||
turn_cache_read_tokens += sdk_msg.usage.get(
|
||||
"cache_read_input_tokens", 0
|
||||
)
|
||||
turn_cache_creation_tokens += sdk_msg.usage.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
)
|
||||
turn_completion_tokens += sdk_msg.usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
logger.info(
|
||||
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
if sdk_msg.total_cost_usd is not None:
|
||||
turn_cost_usd = sdk_msg.total_cost_usd
|
||||
|
||||
# Emit compaction end if SDK finished compacting.
|
||||
# When compaction ends, sync TranscriptBuilder with
|
||||
# the CLI's compacted session file so the uploaded
|
||||
# transcript reflects compaction.
|
||||
compaction_events = await compaction.emit_end_if_ready(session)
|
||||
for ev in compaction_events:
|
||||
# Emit compaction end if SDK finished compacting
|
||||
for ev in await compaction.emit_end_if_ready(session):
|
||||
yield ev
|
||||
if compaction_events and sdk_cwd:
|
||||
cli_content = await read_cli_session_file(sdk_cwd)
|
||||
if cli_content:
|
||||
transcript_builder.replace_entries(
|
||||
cli_content, log_prefix=log_prefix
|
||||
)
|
||||
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
@@ -1440,27 +1325,6 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# Emit token usage to the client (must be in try to reach SSE stream).
|
||||
# Session persistence of usage is in finally to stay consistent with
|
||||
# rate-limit recording even if an exception interrupts between here
|
||||
# and the finally block.
|
||||
# Compute total_tokens once; reused in the finally block for
|
||||
# session persistence and rate-limit recording.
|
||||
total_tokens = (
|
||||
turn_prompt_tokens
|
||||
+ turn_cache_read_tokens
|
||||
+ turn_cache_creation_tokens
|
||||
+ turn_completion_tokens
|
||||
)
|
||||
if total_tokens > 0:
|
||||
yield StreamUsage(
|
||||
promptTokens=turn_prompt_tokens,
|
||||
completionTokens=turn_completion_tokens,
|
||||
totalTokens=total_tokens,
|
||||
cacheReadTokens=turn_cache_read_tokens,
|
||||
cacheCreationTokens=turn_cache_creation_tokens,
|
||||
)
|
||||
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
# to avoid double-uploads (the success path used to upload the
|
||||
# old resume file, then the finally block overwrote it with the
|
||||
@@ -1525,48 +1389,6 @@ async def stream_chat_completion_sdk(
|
||||
except Exception:
|
||||
logger.warning("OTEL context teardown failed", exc_info=True)
|
||||
|
||||
# --- Persist token usage to session + rate-limit counters ---
|
||||
# Both must live in finally so they stay consistent even when an
|
||||
# exception interrupts the try block after StreamUsage was yielded.
|
||||
# total_tokens is computed once before StreamUsage yield above.
|
||||
if total_tokens > 0:
|
||||
if session is not None:
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
|
||||
"output=%d, total=%d, cost_usd=%s",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
total_tokens,
|
||||
turn_cost_usd,
|
||||
)
|
||||
if user_id and total_tokens > 0:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(
|
||||
"%s Failed to record token usage: %s",
|
||||
log_prefix,
|
||||
usage_err,
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
# This MUST run in finally to persist messages even when the generator
|
||||
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
|
||||
@@ -1662,6 +1484,6 @@ async def _update_title_async(
|
||||
)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning("[SDK] Failed to update session title: %s", e)
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
|
||||
@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -13,17 +13,10 @@ filesystem for self-hosted) — no DB column needed.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import openai
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,11 +34,6 @@ STRIPPABLE_TYPES = frozenset(
|
||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||
)
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
@@ -94,11 +82,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = entry.get("parentUuid", "")
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if (
|
||||
entry.get("type", "") in STRIPPABLE_TYPES
|
||||
and uid
|
||||
and not entry.get("isCompactSummary")
|
||||
):
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
@@ -109,9 +93,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
while parent in stripped_uuids:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -124,9 +106,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
@@ -157,78 +137,32 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
|
||||
return None
|
||||
return project_dir
|
||||
|
||||
|
||||
async def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any mid-stream compaction.
|
||||
|
||||
After the CLI compacts context, its session file contains the compacted
|
||||
conversation. Reading this file lets ``TranscriptBuilder`` replace its
|
||||
uncompacted entries with the CLI's compacted version.
|
||||
"""
|
||||
import aiofiles
|
||||
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file in %s", project_dir)
|
||||
return None
|
||||
# Pick the most recently modified file (there should only be one per turn).
|
||||
# Guard against races where a file is deleted between glob and stat.
|
||||
candidates: list[tuple[float, Path]] = []
|
||||
for p in jsonl_files:
|
||||
try:
|
||||
candidates.append((p.stat().st_mtime, p))
|
||||
except OSError:
|
||||
continue
|
||||
if not candidates:
|
||||
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
|
||||
return None
|
||||
# Resolve + prefix check to prevent symlink escapes.
|
||||
session_file = max(candidates, key=lambda item: item[0])[1]
|
||||
real_path = str(session_file.resolve())
|
||||
if not real_path.startswith(project_dir + os.sep):
|
||||
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
|
||||
return None
|
||||
try:
|
||||
async with aiofiles.open(real_path) as f:
|
||||
content = await f.read()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
real_path,
|
||||
len(content),
|
||||
logger.warning(
|
||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory."""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
||||
else:
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
@@ -246,7 +180,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -256,17 +190,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -410,14 +344,11 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
@@ -440,10 +371,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -463,14 +394,10 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except FileNotFoundError:
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -478,171 +405,15 @@ async def download_transcript(
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction
|
||||
# ---------------------------------------------------------------------------
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Transcripts above this byte threshold are compacted at download time.
|
||||
COMPACT_THRESHOLD_BYTES = 400_000
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string."""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif block.get("type") == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content", "")
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
text = sub.get("text")
|
||||
str_parts.append(
|
||||
str(text) if text is not None else json.dumps(sub)
|
||||
)
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to message dicts for compress_context."""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format."""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
cfg: ChatConfig,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback."""
|
||||
try:
|
||||
async with openai.AsyncOpenAI(
|
||||
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
|
||||
) as client:
|
||||
return await compress_context(messages=messages, model=model, client=client)
|
||||
await storage.delete(path)
|
||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await compress_context(messages=messages, model=model, client=None)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
"""
|
||||
cfg = ChatConfig()
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
|
||||
if not result.was_compacted:
|
||||
logger.info("%s Transcript already within token budget", log_prefix)
|
||||
return content
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||
|
||||
@@ -31,7 +31,6 @@ class TranscriptEntry(BaseModel):
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
isCompactSummary: bool | None = None
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
@@ -79,12 +78,10 @@ class TranscriptBuilder:
|
||||
)
|
||||
continue
|
||||
|
||||
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
|
||||
# Compaction summaries may have type "summary" but must be preserved
|
||||
# so --resume can reconstruct the compacted conversation.
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
entry_type = data.get("type", "")
|
||||
is_compact = data.get("isCompactSummary", False)
|
||||
if entry_type in STRIPPABLE_TYPES and not is_compact:
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
@@ -92,7 +89,6 @@ class TranscriptBuilder:
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
isCompactSummary=True if is_compact else None,
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
@@ -181,33 +177,6 @@ class TranscriptBuilder:
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Replace all entries with compacted JSONL content.
|
||||
|
||||
Called after the CLI performs mid-stream compaction so the builder's
|
||||
state reflects the compacted conversation instead of the full
|
||||
pre-compaction history.
|
||||
"""
|
||||
prev_count = len(self._entries)
|
||||
temp = TranscriptBuilder()
|
||||
try:
|
||||
temp.load_previous(content, log_prefix=log_prefix)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s Failed to parse compacted transcript; keeping %d existing entries",
|
||||
log_prefix,
|
||||
prev_count,
|
||||
)
|
||||
return
|
||||
self._entries = temp._entries
|
||||
self._last_uuid = temp._last_uuid
|
||||
logger.info(
|
||||
"%s Replaced %d entries with %d compacted entries",
|
||||
log_prefix,
|
||||
prev_count,
|
||||
len(self._entries),
|
||||
)
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
|
||||
@@ -2,25 +2,14 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
COMPACT_MSG_ID_PREFIX,
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
read_cli_session_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
@@ -46,14 +35,6 @@ PROGRESS_ENTRY = {
|
||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||
}
|
||||
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"parentUuid": None,
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Summary of previous conversation..."},
|
||||
}
|
||||
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
@@ -256,121 +237,6 @@ class TestStripProgressEntries:
|
||||
# Should return just a newline (empty content stripped)
|
||||
assert result.strip() == ""
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
projects = tmp_path / "projects"
|
||||
projects.mkdir()
|
||||
result = _cli_project_dir("/tmp/copilot-abc")
|
||||
assert result is not None
|
||||
assert "projects" in result
|
||||
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
projects = tmp_path / "projects"
|
||||
projects.mkdir()
|
||||
# A cwd that encodes to something with .. shouldn't escape
|
||||
result = _cli_project_dir("/tmp/copilot-test")
|
||||
# Should return a valid path (no traversal possible with alphanum encoding)
|
||||
assert result is None or result.startswith(str(projects))
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reads_session_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
# Create the CLI project directory structure
|
||||
cwd = "/tmp/copilot-testread"
|
||||
import re
|
||||
|
||||
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
project_dir = tmp_path / "projects" / encoded
|
||||
project_dir.mkdir(parents=True)
|
||||
# Write a session file
|
||||
session_file = project_dir / "test-session.jsonl"
|
||||
session_file.write_text(json.dumps(ASST_MSG) + "\n")
|
||||
|
||||
result = await read_cli_session_file(cwd)
|
||||
assert result is not None
|
||||
assert "assistant" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
cwd = "/tmp/copilot-nofiles"
|
||||
import re
|
||||
|
||||
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
project_dir = tmp_path / "projects" / encoded
|
||||
project_dir.mkdir(parents=True)
|
||||
# No jsonl files
|
||||
result = await read_cli_session_file(cwd)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
|
||||
(tmp_path / "projects").mkdir()
|
||||
result = await read_cli_session_file("/tmp/copilot-nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _transcript_to_messages / _messages_to_transcript ---
|
||||
|
||||
|
||||
class TestTranscriptMessageConversion:
|
||||
def test_roundtrip_preserves_roles(self):
|
||||
transcript = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
def test_messages_to_transcript_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result) is True
|
||||
|
||||
def test_strips_strippable_types(self):
|
||||
transcript = _make_jsonl(
|
||||
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
|
||||
USER_MSG,
|
||||
ASST_MSG,
|
||||
)
|
||||
messages = _transcript_to_messages(transcript)
|
||||
assert len(messages) == 2 # progress entry skipped
|
||||
|
||||
def test_flattens_assistant_content_blocks(self):
|
||||
asst_with_blocks = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "tool_use", "name": "bash"},
|
||||
],
|
||||
},
|
||||
}
|
||||
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
|
||||
assert len(messages) == 1
|
||||
assert "hello" in messages[0]["content"]
|
||||
assert "[tool_use: bash]" in messages[0]["content"]
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
result = _messages_to_transcript([])
|
||||
assert result == ""
|
||||
|
||||
def test_no_strippable_entries(self):
|
||||
"""When there's nothing to strip, output matches input structure."""
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
@@ -416,654 +282,3 @@ class TestTranscriptMessageConversion:
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- TranscriptBuilder ---
|
||||
|
||||
|
||||
class TestTranscriptBuilderReplaceEntries:
|
||||
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
|
||||
|
||||
def test_replace_entries_with_valid_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
builder.append_assistant([{"type": "text", "text": "world"}])
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# Replace with compacted content (one user + one assistant)
|
||||
compacted = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
def test_replace_entries_keeps_old_on_corrupt_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
assert builder.entry_count == 1
|
||||
|
||||
# Corrupt content that fails to parse
|
||||
builder.replace_entries("not valid json at all\n")
|
||||
# Should still have old entries (load_previous skips invalid lines,
|
||||
# but if ALL lines are invalid, temp builder is empty → exception path)
|
||||
assert builder.entry_count >= 0 # doesn't crash
|
||||
|
||||
def test_replace_entries_with_empty_content(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
assert builder.entry_count == 1
|
||||
|
||||
builder.replace_entries("")
|
||||
# Empty content → load_previous returns early → temp is empty
|
||||
# replace_entries swaps to empty (0 entries)
|
||||
assert builder.entry_count == 0
|
||||
|
||||
def test_replace_entries_filters_strippable_types(self):
|
||||
"""Strippable types (progress, file-history-snapshot) are filtered out."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hello")
|
||||
|
||||
content = _make_jsonl(
|
||||
{"type": "progress", "uuid": "p1", "message": {}},
|
||||
USER_MSG,
|
||||
ASST_MSG,
|
||||
)
|
||||
builder.replace_entries(content)
|
||||
# Only user + assistant should remain (progress filtered)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
def test_replace_entries_preserves_uuids(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(content)
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
first = json.loads(lines[0])
|
||||
assert first["uuid"] == "u1"
|
||||
|
||||
|
||||
class TestTranscriptBuilderBasic:
|
||||
def test_append_user_and_assistant(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "hello"}])
|
||||
assert builder.entry_count == 2
|
||||
assert not builder.is_empty
|
||||
|
||||
def test_to_jsonl_empty(self):
|
||||
builder = TranscriptBuilder()
|
||||
assert builder.to_jsonl() == ""
|
||||
assert builder.is_empty
|
||||
|
||||
def test_load_previous_and_append(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
builder.load_previous(content)
|
||||
assert builder.entry_count == 2
|
||||
builder.append_user("new message")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_consecutive_assistant_entries_share_message_id(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "part1"}])
|
||||
builder.append_assistant([{"type": "text", "text": "part2"}])
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
asst1 = json.loads(lines[1])
|
||||
asst2 = json.loads(lines[2])
|
||||
assert asst1["message"]["id"] == asst2["message"]["id"]
|
||||
|
||||
def test_non_consecutive_assistant_entries_get_new_id(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("hi")
|
||||
builder.append_assistant([{"type": "text", "text": "response1"}])
|
||||
builder.append_user("followup")
|
||||
builder.append_assistant([{"type": "text", "text": "response2"}])
|
||||
|
||||
jsonl = builder.to_jsonl()
|
||||
lines = jsonl.strip().split("\n")
|
||||
asst1 = json.loads(lines[1])
|
||||
asst2 = json.loads(lines[3])
|
||||
assert asst1["message"]["id"] != asst2["message"]["id"]
|
||||
|
||||
|
||||
class TestCompactSummaryRoundtrip:
|
||||
"""Verify isCompactSummary survives export→reload roundtrip."""
|
||||
|
||||
def test_load_previous_preserves_compact_summary(self):
|
||||
"""Compaction summary with type 'summary' should not be stripped."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(content)
|
||||
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_export_reload_preserves_compact_summary(self):
|
||||
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder1 = TranscriptBuilder()
|
||||
builder1.load_previous(content)
|
||||
assert builder1.entry_count == 3
|
||||
|
||||
exported = builder1.to_jsonl()
|
||||
# Verify isCompactSummary is in the exported JSONL
|
||||
first_line = json.loads(exported.strip().split("\n")[0])
|
||||
assert first_line.get("isCompactSummary") is True
|
||||
|
||||
# Reload and verify it's still preserved
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(exported)
|
||||
assert builder2.entry_count == 3
|
||||
|
||||
def test_strip_progress_preserves_compact_summary(self):
|
||||
"""strip_progress_entries should keep isCompactSummary entries."""
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
stripped = strip_progress_entries(content)
|
||||
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||
types = [e.get("type") for e in entries]
|
||||
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
|
||||
compact = [e for e in entries if e.get("isCompactSummary")]
|
||||
assert len(compact) == 1
|
||||
|
||||
def test_regular_summary_still_stripped(self):
|
||||
"""Non-compact summaries should still be stripped."""
|
||||
regular_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "rs1",
|
||||
"summary": "Session summary",
|
||||
}
|
||||
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
|
||||
stripped = strip_progress_entries(content)
|
||||
entries = [json.loads(line) for line in stripped.strip().split("\n")]
|
||||
types = [e.get("type") for e in entries]
|
||||
assert "summary" not in types
|
||||
|
||||
def test_replace_entries_preserves_compact_summary(self):
|
||||
"""replace_entries should preserve isCompactSummary entries."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("old")
|
||||
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
|
||||
builder.replace_entries(content)
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# Verify by re-exporting
|
||||
exported = builder.to_jsonl()
|
||||
first = json.loads(exported.strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
|
||||
|
||||
# --- _flatten_assistant_content ---
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: read]" in result
|
||||
|
||||
def test_string_blocks(self):
|
||||
"""Plain strings in the list should be included."""
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
def test_tool_use_missing_name(self):
|
||||
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
|
||||
|
||||
|
||||
# --- _flatten_tool_result_content ---
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [
|
||||
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "simple result"
|
||||
|
||||
def test_tool_result_with_nested_list(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [
|
||||
{"type": "text", "text": "line 1"},
|
||||
{"type": "text", "text": "line 2"},
|
||||
],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
|
||||
|
||||
def test_text_blocks(self):
|
||||
blocks = [{"type": "text", "text": "some text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "some text"
|
||||
|
||||
def test_string_items(self):
|
||||
assert _flatten_tool_result_content(["raw string"]) == "raw string"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_tool_result_none_text_uses_json(self):
|
||||
"""Dicts without text key fall back to json.dumps."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback includes the key
|
||||
|
||||
|
||||
# --- _transcript_to_messages ---
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_conversation(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hello"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hi there"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0] == {"role": "user", "content": "hello"}
|
||||
assert msgs[1] == {"role": "assistant", "content": "hi there"}
|
||||
|
||||
def test_strips_progress_entries(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"message": {"role": "user", "content": "..."},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["role"] == "user"
|
||||
assert msgs[1]["role"] == "assistant"
|
||||
|
||||
def test_preserves_compact_summaries(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Summary of previous..."},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["content"] == "Summary of previous..."
|
||||
|
||||
def test_strips_regular_summary(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "s1",
|
||||
"message": {"role": "user", "content": "Session summary"},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["content"] == "hi"
|
||||
|
||||
def test_skips_entries_without_role(self):
|
||||
content = _make_jsonl(
|
||||
{"type": "user", "uuid": "u1", "message": {}},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
|
||||
def test_tool_result_content(self):
|
||||
content = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": "file contents",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
assert "file contents" in msgs[0]["content"]
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
assert _transcript_to_messages(" \n ") == []
|
||||
|
||||
def test_invalid_json_lines_skipped(self):
|
||||
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
|
||||
msgs = _transcript_to_messages(content)
|
||||
assert len(msgs) == 1
|
||||
|
||||
|
||||
# --- _messages_to_transcript ---
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_basic_roundtrip_structure(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert result.endswith("\n")
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
assert len(lines) == 2
|
||||
|
||||
# User entry
|
||||
assert lines[0]["type"] == "user"
|
||||
assert lines[0]["message"]["role"] == "user"
|
||||
assert lines[0]["message"]["content"] == "hello"
|
||||
assert lines[0]["parentUuid"] is None
|
||||
|
||||
# Assistant entry
|
||||
assert lines[1]["type"] == "assistant"
|
||||
assert lines[1]["message"]["role"] == "assistant"
|
||||
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
|
||||
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
|
||||
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "q2"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
assert lines[0]["parentUuid"] is None
|
||||
assert lines[1]["parentUuid"] == lines[0]["uuid"]
|
||||
assert lines[2]["parentUuid"] == lines[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_assistant_empty_content(self):
|
||||
messages = [{"role": "assistant", "content": ""}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["message"]["content"] == []
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "world"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
|
||||
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
|
||||
|
||||
|
||||
class TestTranscriptCompactionRoundtrip:
|
||||
def test_content_preserved_through_roundtrip(self):
|
||||
"""Messages→transcript→messages preserves content."""
|
||||
original = [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
{"role": "user", "content": "Thanks"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
recovered = _transcript_to_messages(transcript)
|
||||
assert len(recovered) == len(original)
|
||||
for orig, rec in zip(original, recovered):
|
||||
assert orig["role"] == rec["role"]
|
||||
assert orig["content"] == rec["content"]
|
||||
|
||||
def test_full_transcript_to_messages_and_back(self):
|
||||
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "explain python"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Python is a programming language."}
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "t1",
|
||||
"content": "output of ls",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
msgs1 = _transcript_to_messages(source)
|
||||
assert len(msgs1) == 3
|
||||
|
||||
rebuilt = _messages_to_transcript(msgs1)
|
||||
msgs2 = _transcript_to_messages(rebuilt)
|
||||
assert len(msgs2) == len(msgs1)
|
||||
for m1, m2 in zip(msgs1, msgs2):
|
||||
assert m1["role"] == m2["role"]
|
||||
# Content may differ in format (list vs string) but text is preserved
|
||||
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
|
||||
|
||||
|
||||
# --- compact_transcript ---
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self):
|
||||
"""Transcripts with < 2 messages can't be compacted."""
|
||||
single = _make_jsonl(
|
||||
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
||||
)
|
||||
result = await compact_transcript(single)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_returns_none(self):
|
||||
result = await compact_transcript("")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_produces_valid_transcript(self, monkeypatch):
|
||||
"""When compress_context compacts, result should be valid JSONL."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
mock_result = CompressResult(
|
||||
messages=[
|
||||
{"role": "user", "content": "Summary of conversation"},
|
||||
{"role": "assistant", "content": "Acknowledged"},
|
||||
],
|
||||
token_count=50,
|
||||
was_compacted=True,
|
||||
original_token_count=5000,
|
||||
messages_summarized=10,
|
||||
messages_dropped=5,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(return_value=mock_result),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "msg1"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "reply1"}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"message": {"role": "user", "content": "msg2"},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
|
||||
# Verify compacted content
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["content"] == "Summary of conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_compaction_needed_returns_original(self, monkeypatch):
|
||||
"""When compress_context says no compaction needed, return original."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
mock_result = CompressResult(
|
||||
messages=[], token_count=100, was_compacted=False, original_token_count=100
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(return_value=mock_result),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result == source # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_failure_returns_none(self, monkeypatch):
|
||||
"""When _run_compression raises, compact_transcript returns None."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
|
||||
)
|
||||
|
||||
source = _make_jsonl(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
},
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
},
|
||||
)
|
||||
result = await compact_transcript(source)
|
||||
assert result is None
|
||||
|
||||
@@ -23,11 +23,6 @@ from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
@@ -43,7 +38,6 @@ from .response_model import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
_notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_sessions: dict[str, asyncio.Task] = {}
|
||||
@@ -751,29 +745,6 @@ async def mark_session_completed(
|
||||
|
||||
# Clean up local session reference if exists
|
||||
_local_sessions.pop(session_id, None)
|
||||
|
||||
# Publish copilot completion notification via WebSocket
|
||||
if meta:
|
||||
parsed = _parse_session_meta(meta, session_id)
|
||||
if parsed.user_id:
|
||||
try:
|
||||
await _notification_bus.publish(
|
||||
NotificationEvent(
|
||||
user_id=parsed.user_id,
|
||||
payload=CopilotCompletionPayload(
|
||||
type="copilot_completion",
|
||||
event="session_completed",
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to publish copilot completion notification "
|
||||
f"for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -829,12 +829,8 @@ class AgentFixer:
|
||||
|
||||
For nodes whose block has category "AI", this function ensures that the
|
||||
input_default has a "model" parameter set to one of the allowed models.
|
||||
If missing or set to an unsupported value, it is replaced with the
|
||||
appropriate default.
|
||||
|
||||
Blocks that define their own ``enum`` constraint on the ``model`` field
|
||||
in their inputSchema (e.g. PerplexityBlock) are validated against that
|
||||
enum instead of the generic allowed set.
|
||||
If missing or set to an unsupported value, it is replaced with
|
||||
default_model.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to fix
|
||||
@@ -844,7 +840,7 @@ class AgentFixer:
|
||||
Returns:
|
||||
The fixed agent dictionary
|
||||
"""
|
||||
generic_allowed_models = {"gpt-4o", "claude-opus-4-6"}
|
||||
allowed_models = {"gpt-4o", "claude-opus-4-6"}
|
||||
|
||||
# Create a mapping of block_id to block for quick lookup
|
||||
block_map = {block.get("id"): block for block in blocks}
|
||||
@@ -872,36 +868,20 @@ class AgentFixer:
|
||||
input_default = node.get("input_default", {})
|
||||
current_model = input_default.get("model")
|
||||
|
||||
# Determine allowed models and default from the block's schema.
|
||||
# Blocks with a block-specific enum on the model field (e.g.
|
||||
# PerplexityBlock) use their own enum values; others use the
|
||||
# generic set.
|
||||
model_schema = (
|
||||
block.get("inputSchema", {}).get("properties", {}).get("model", {})
|
||||
)
|
||||
block_model_enum = model_schema.get("enum")
|
||||
|
||||
if block_model_enum:
|
||||
allowed_models = set(block_model_enum)
|
||||
fallback_model = model_schema.get("default", block_model_enum[0])
|
||||
else:
|
||||
allowed_models = generic_allowed_models
|
||||
fallback_model = default_model
|
||||
|
||||
if current_model not in allowed_models:
|
||||
block_name = block.get("name", "Unknown AI Block")
|
||||
if current_model is None:
|
||||
self.add_fix_log(
|
||||
f"Added model parameter '{fallback_model}' to AI "
|
||||
f"Added model parameter '{default_model}' to AI "
|
||||
f"block node {node_id} ({block_name})"
|
||||
)
|
||||
else:
|
||||
self.add_fix_log(
|
||||
f"Replaced unsupported model '{current_model}' "
|
||||
f"with '{fallback_model}' on AI block node "
|
||||
f"with '{default_model}' on AI block node "
|
||||
f"{node_id} ({block_name})"
|
||||
)
|
||||
input_default["model"] = fallback_model
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
@@ -475,111 +475,6 @@ class TestFixAiModelParameter:
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
|
||||
|
||||
def test_block_specific_enum_uses_block_default(self):
|
||||
"""Blocks with their own model enum (e.g. PerplexityBlock) should use
|
||||
the block's allowed models and default, not the generic ones."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "gpt-5.2-2025-12-11"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
|
||||
|
||||
def test_block_specific_enum_valid_model_unchanged(self):
|
||||
"""A valid block-specific model should not be replaced."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "perplexity/sonar-pro"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_block_specific_enum_missing_model_gets_block_default(self):
|
||||
"""Missing model on a block with enum should use the block's default."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id, input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
|
||||
|
||||
|
||||
class TestFixAgentExecutorBlocks:
|
||||
"""Tests for fix_agent_executor_blocks."""
|
||||
|
||||
@@ -21,11 +21,9 @@ Lifecycle
|
||||
Cost control
|
||||
------------
|
||||
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
|
||||
(default: ``"pause"``) and ``auto_resume`` (default: ``True``). The explicit
|
||||
per-turn ``pause_sandbox()`` call is the primary mechanism; the lifecycle
|
||||
timeout is a safety net (default: 5 min). ``auto_resume`` ensures that paused
|
||||
sandboxes wake transparently on SDK activity, making the aggressive safety-net
|
||||
timeout safe. Paused sandboxes are free.
|
||||
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
|
||||
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
|
||||
free.
|
||||
|
||||
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
|
||||
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
|
||||
@@ -42,7 +40,6 @@ import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
@@ -119,10 +116,9 @@ async def get_or_create_sandbox(
|
||||
removes the need for a separate lock key.
|
||||
|
||||
*timeout* controls how long the e2b sandbox may run continuously before
|
||||
the ``on_timeout`` lifecycle rule fires (default: 5 min).
|
||||
the ``on_timeout`` lifecycle rule fires (default: 3 h).
|
||||
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
|
||||
or ``"kill"``. When ``"pause"``, ``auto_resume`` is enabled so paused
|
||||
sandboxes wake transparently on SDK activity.
|
||||
or ``"kill"``.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _sandbox_key(session_id)
|
||||
@@ -160,15 +156,11 @@ async def get_or_create_sandbox(
|
||||
|
||||
# We hold the slot — create the sandbox.
|
||||
try:
|
||||
lifecycle = SandboxLifecycle(
|
||||
on_timeout=on_timeout,
|
||||
auto_resume=on_timeout == "pause",
|
||||
)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
lifecycle=lifecycle,
|
||||
lifecycle={"on_timeout": on_timeout},
|
||||
)
|
||||
try:
|
||||
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
|
||||
|
||||
@@ -157,17 +157,14 @@ class TestGetOrCreateSandbox:
|
||||
|
||||
assert result is new_sb
|
||||
mock_cls.create.assert_awaited_once()
|
||||
# Verify lifecycle: pause + auto_resume enabled
|
||||
# Verify lifecycle param is set
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {
|
||||
"on_timeout": "pause",
|
||||
"auto_resume": True,
|
||||
}
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
|
||||
# sandbox_id should be saved to Redis
|
||||
redis.set.assert_awaited()
|
||||
|
||||
def test_create_with_on_timeout_kill(self):
|
||||
"""on_timeout='kill' disables auto_resume automatically."""
|
||||
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
@@ -182,10 +179,7 @@ class TestGetOrCreateSandbox:
|
||||
)
|
||||
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {
|
||||
"on_timeout": "kill",
|
||||
"auto_resume": False,
|
||||
}
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
|
||||
|
||||
def test_create_failure_releases_slot(self):
|
||||
"""If sandbox creation fails, the Redis creation slot is deleted."""
|
||||
|
||||
@@ -8,16 +8,11 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
@@ -25,26 +20,6 @@ from .utils import match_credentials_to_requirements
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if not db.is_connected():
|
||||
return await get_database_manager_async_client().get_credits(user_id)
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if not db.is_connected():
|
||||
return await get_database_manager_async_client().spend_credits(
|
||||
user_id, cost, metadata
|
||||
)
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
@@ -136,23 +111,6 @@ async def execute_block(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
|
||||
# Pre-execution credit check
|
||||
cost, cost_filter = block_usage_cost(block, input_data)
|
||||
has_cost = cost > 0
|
||||
if has_cost:
|
||||
balance = await _get_credits(user_id)
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to run '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
@@ -161,37 +119,6 @@ async def execute_block(
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=block_id,
|
||||
block=block.name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
logger.warning(
|
||||
"Post-exec credit charge failed for block %s (cost=%d)",
|
||||
block.name,
|
||||
cost,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to complete '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
@@ -202,16 +129,16 @@ async def execute_block(
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning("Block execution failed: %s", e)
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="An unexpected error occurred while executing the block",
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,506 +0,0 @@
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
|
||||
|
||||
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||
"""Create a minimal mock block for execute_block()."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
|
||||
async def _execute(
|
||||
input_data: dict, **kwargs: Any
|
||||
) -> AsyncIterator[tuple[str, Any]]:
|
||||
yield "result", "ok"
|
||||
|
||||
mock.execute = _execute
|
||||
return mock
|
||||
|
||||
|
||||
def _patch_workspace():
|
||||
"""Patch workspace_db to return a mock workspace."""
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_ws_db = MagicMock()
|
||||
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credit charging tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestExecuteBlockCreditCharging:
|
||||
async def test_charges_credits_when_cost_is_positive(self):
|
||||
"""Block with cost > 0 should call spend_credits after execution."""
|
||||
block = _make_block()
|
||||
mock_spend = AsyncMock()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {"key": "val"}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=100,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=mock_spend,
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
mock_spend.assert_awaited_once()
|
||||
call_kwargs = mock_spend.call_args.kwargs
|
||||
assert call_kwargs["cost"] == 10
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=5, # balance < cost (10)
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
async def test_no_charge_when_cost_is_zero(self):
|
||||
"""Block with cost 0 should not call spend_credits."""
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
) as mock_get_credits,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
) as mock_spend_credits,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
# Credit functions should not be called at all for zero-cost blocks
|
||||
mock_get_credits.assert_not_awaited()
|
||||
mock_spend_credits.assert_not_awaited()
|
||||
|
||||
async def test_returns_error_on_post_exec_insufficient_balance(self):
|
||||
"""If charging fails after execution, return ErrorResponse."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
block = _make_block()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=15, # passes pre-check
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
), # fails during actual charge (race with concurrent spend)
|
||||
),
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type coercion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
"""Create a mock input_schema with model_fields matching the given annotations."""
|
||||
schema = MagicMock()
|
||||
model_fields = {}
|
||||
for name, ann in annotations.items():
|
||||
field = MagicMock()
|
||||
field.annotation = ann
|
||||
model_fields[name] = field
|
||||
schema.model_fields = model_fields
|
||||
return schema
|
||||
|
||||
|
||||
def _make_coerce_block(
|
||||
block_id: str,
|
||||
name: str,
|
||||
annotations: dict[str, Any],
|
||||
outputs: dict[str, list[Any]] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock block with typed annotations and a simple execute method."""
|
||||
block = MagicMock()
|
||||
block.id = block_id
|
||||
block.name = name
|
||||
block.input_schema = _make_block_schema(annotations)
|
||||
|
||||
captured_inputs: dict[str, Any] = {}
|
||||
|
||||
async def mock_execute(input_data: dict, **_kwargs: Any):
|
||||
captured_inputs.update(input_data)
|
||||
for output_name, values in (outputs or {"result": ["ok"]}).items():
|
||||
for v in values:
|
||||
yield output_name, v
|
||||
|
||||
block.execute = mock_execute
|
||||
block._captured_inputs = captured_inputs
|
||||
return block
|
||||
|
||||
|
||||
_TEST_SESSION_ID = "test-session-coerce"
|
||||
_TEST_USER_ID = "test-user-coerce"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_coerce_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
{"values": list[list[str]], "spreadsheet_id": str},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="sheets-write",
|
||||
input_data={
|
||||
"values": '[["Name","Score"],["Alice","90"],["Bob","85"]]',
|
||||
"spreadsheet_id": "abc123",
|
||||
},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
assert block._captured_inputs["values"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
assert isinstance(block._captured_inputs["values"], list)
|
||||
assert isinstance(block._captured_inputs["values"][0], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
block = _make_coerce_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
{"items": list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="list-block",
|
||||
input_data={"items": '["a","b","c"]'},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-2",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["items"] == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
block = _make_coerce_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
{"config": dict[str, str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="dict-block",
|
||||
input_data={"config": '{"key": "value", "foo": "bar"}'},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-3",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["config"] == {"key": "value", "foo": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_coercion_when_type_matches():
|
||||
"""Already-correct types pass through without coercion."""
|
||||
block = _make_coerce_block(
|
||||
"pass-through",
|
||||
"Pass Through",
|
||||
{"values": list[list[str]], "name": str},
|
||||
)
|
||||
|
||||
original_values = [["a", "b"], ["c", "d"]]
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="pass-through",
|
||||
input_data={"values": original_values, "name": "test"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-4",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["values"] == original_values
|
||||
assert block._captured_inputs["name"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
block = _make_coerce_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
{"count": int},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="int-block",
|
||||
input_data={"count": "42"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-5",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["count"] == 42
|
||||
assert isinstance(block._captured_inputs["count"], int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_skips_none_values():
|
||||
"""None values are not coerced (they may be optional fields)."""
|
||||
block = _make_coerce_block(
|
||||
"optional-block",
|
||||
"Optional Block",
|
||||
{"data": list[str], "label": str},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="optional-block",
|
||||
input_data={"label": "test"},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-6",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert "data" not in block._captured_inputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_union_type_preserves_valid_member():
|
||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||
block = _make_coerce_block(
|
||||
"union-block",
|
||||
"Union Block",
|
||||
{"content": str | list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="union-block",
|
||||
input_data={"content": ["a", "b"]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-7",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["content"] == ["a", "b"]
|
||||
assert isinstance(block._captured_inputs["content"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_inner_elements_of_generic():
|
||||
"""Inner elements of generic containers are recursively coerced."""
|
||||
block = _make_coerce_block(
|
||||
"inner-coerce",
|
||||
"Inner Coerce",
|
||||
{"values": list[str]},
|
||||
)
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="ws-1")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
):
|
||||
response = await execute_block(
|
||||
block=block,
|
||||
block_id="inner-coerce",
|
||||
input_data={"values": [1, 2, 3]},
|
||||
user_id=_TEST_USER_ID,
|
||||
session_id=_TEST_SESSION_ID,
|
||||
node_exec_id="exec-8",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert block._captured_inputs["values"] == ["1", "2", "3"]
|
||||
assert all(isinstance(v, str) for v in block._captured_inputs["values"])
|
||||
@@ -34,11 +34,6 @@ logger = logging.getLogger(__name__)
|
||||
_AUTH_STATUS_CODES = {401, 403}
|
||||
|
||||
|
||||
def _service_name(host: str) -> str:
|
||||
"""Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev' → 'sentry.dev'"""
|
||||
return host[4:] if host.startswith("mcp.") else host
|
||||
|
||||
|
||||
class RunMCPToolTool(BaseTool):
|
||||
"""
|
||||
Tool for discovering and executing tools on any MCP server.
|
||||
@@ -308,8 +303,8 @@ class RunMCPToolTool(BaseTool):
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unable to connect to {_service_name(server_host(server_url))} "
|
||||
"— no credentials configured."
|
||||
f"The MCP server at {server_host(server_url)} requires authentication, "
|
||||
"but no credential configuration was found."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -317,13 +312,15 @@ class RunMCPToolTool(BaseTool):
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
host = server_host(server_url)
|
||||
service = _service_name(host)
|
||||
return SetupRequirementsResponse(
|
||||
message=(f"To continue, sign in to {service} and approve access."),
|
||||
message=(
|
||||
f"The MCP server at {host} requires authentication. "
|
||||
"Please connect your credentials to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=server_url,
|
||||
agent_name=service,
|
||||
agent_name=f"MCP: {host}",
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
|
||||
@@ -756,4 +756,4 @@ async def test_build_setup_requirements_returns_setup_response():
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_id == _SERVER_URL
|
||||
assert "sign in" in result.message.lower()
|
||||
assert "authentication" in result.message.lower()
|
||||
|
||||
@@ -100,31 +100,19 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: 4,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: 2,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.MISTRAL_LARGE_3: 2,
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: 2,
|
||||
LlmModel.MISTRAL_SMALL_3_2: 1,
|
||||
LlmModel.CODESTRAL: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: 3,
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: 3,
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: 6,
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
@@ -132,7 +120,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.MICROSOFT_PHI_4: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
@@ -140,7 +127,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_3: 3,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
|
||||
@@ -512,10 +512,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Credits ============ #
|
||||
spend_credits = d.spend_credits
|
||||
get_credits = d.get_credits
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
@@ -1,750 +0,0 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.errors import UniqueViolationError
|
||||
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.model import User
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_settings = Settings()
|
||||
|
||||
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
|
||||
|
||||
_tally_seed_tasks: dict[str, asyncio.Task] = {}
|
||||
_TALLY_STALE_SECONDS = 300
|
||||
_MAX_TALLY_ERROR_LENGTH = 200
|
||||
_email_adapter = TypeAdapter(EmailStr)
|
||||
|
||||
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
||||
MAX_BULK_INVITE_ROWS = 500
|
||||
|
||||
|
||||
class InvitedUserRecord(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
created_at=invited_user.createdAt,
|
||||
updated_at=invited_user.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class BulkInvitedUserRowResult(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserRecord] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResult(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResult]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ParsedInviteRow:
|
||||
row_number: int
|
||||
email: str
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
||||
if preferred_name:
|
||||
return preferred_name
|
||||
local_part = email.split("@", 1)[0].strip()
|
||||
return local_part or "user"
|
||||
|
||||
|
||||
def _sanitize_username_base(email: str) -> str:
|
||||
local_part = email.split("@", 1)[0].lower()
|
||||
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
||||
sanitized = sanitized.strip("-")
|
||||
return sanitized[:40] or "user"
|
||||
|
||||
|
||||
async def _generate_unique_profile_username(email: str, tx) -> str:
|
||||
base = _sanitize_username_base(email)
|
||||
|
||||
for _ in range(2):
|
||||
candidate = f"{base}-{uuid4().hex[:6]}"
|
||||
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"username": candidate}
|
||||
)
|
||||
if existing is None:
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(f"Unable to generate unique username for {email}")
|
||||
|
||||
|
||||
async def _ensure_default_profile(
|
||||
user_id: str,
|
||||
email: str,
|
||||
preferred_name: Optional[str],
|
||||
tx,
|
||||
) -> None:
|
||||
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing_profile is not None:
|
||||
return
|
||||
|
||||
username = await _generate_unique_profile_username(email, tx)
|
||||
await prisma.models.Profile.prisma(tx).create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=user_id,
|
||||
name=_default_profile_name(email, preferred_name),
|
||||
username=username,
|
||||
description="I'm new here",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
||||
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _apply_tally_understanding(
|
||||
user_id: str,
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(payload)},
|
||||
"update": {"data": SafeJson(payload)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users(
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[InvitedUserRecord], int]:
|
||||
total = await prisma.models.InvitedUser.prisma().count()
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
order={"createdAt": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
normalized_email = normalize_email(email)
|
||||
normalized_name = _normalize_name(name)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_user is not None:
|
||||
raise PreconditionFailed("An active user with this email already exists")
|
||||
|
||||
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_invited_user is not None:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
|
||||
try:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().create(
|
||||
data={
|
||||
"email": normalized_email,
|
||||
"name": normalized_name,
|
||||
"status": prisma.enums.InvitedUserStatus.INVITED,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
schedule_invited_user_tally_precompute(invited_user.id)
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
|
||||
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
||||
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
||||
)
|
||||
if revoked_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
return InvitedUserRecord.from_db(revoked_user)
|
||||
|
||||
|
||||
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
||||
|
||||
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": None,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
"tallyComputedAt": None,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
if refreshed_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
schedule_invited_user_tally_precompute(invited_user_id)
|
||||
return InvitedUserRecord.from_db(refreshed_user)
|
||||
|
||||
|
||||
def _decode_bulk_invite_file(content: bytes) -> str:
|
||||
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
||||
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
||||
|
||||
try:
|
||||
return content.decode("utf-8-sig")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
||||
|
||||
|
||||
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
||||
indexed_rows: list[tuple[int, list[str]]] = []
|
||||
|
||||
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
||||
normalized_row = [cell.strip() for cell in row]
|
||||
if any(normalized_row):
|
||||
indexed_rows.append((row_number, normalized_row))
|
||||
|
||||
if not indexed_rows:
|
||||
return []
|
||||
|
||||
header = [cell.lower() for cell in indexed_rows[0][1]]
|
||||
has_header = "email" in header
|
||||
email_index = header.index("email") if has_header else 0
|
||||
name_index: Optional[int] = (
|
||||
header.index("name")
|
||||
if has_header and "name" in header
|
||||
else (1 if not has_header else None)
|
||||
)
|
||||
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
||||
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
for row_number, row in data_rows:
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
email = row[email_index].strip() if len(row) > email_index else ""
|
||||
name = (
|
||||
row[name_index].strip()
|
||||
if name_index is not None and len(row) > name_index
|
||||
else ""
|
||||
)
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=email,
|
||||
name=name or None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
|
||||
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=line,
|
||||
name=None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> list[_ParsedInviteRow]:
|
||||
text = _decode_bulk_invite_file(content)
|
||||
file_name = filename.lower() if filename else ""
|
||||
parsed_rows = (
|
||||
_parse_bulk_invite_csv(text)
|
||||
if file_name.endswith(".csv")
|
||||
else _parse_bulk_invite_text(text)
|
||||
)
|
||||
|
||||
if not parsed_rows:
|
||||
raise ValueError("Invite file did not contain any emails")
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
async def bulk_create_invited_users_from_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> BulkInvitedUsersResult:
|
||||
parsed_rows = _parse_bulk_invite_file(filename, content)
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
results: list[BulkInvitedUserRowResult] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for row in parsed_rows:
|
||||
row_name = _normalize_name(row.name)
|
||||
|
||||
try:
|
||||
validated_email = _email_adapter.validate_python(row.email)
|
||||
except ValidationError:
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=row.email or None,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Invalid email address",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_email = normalize_email(str(validated_email))
|
||||
if normalized_email in seen_emails:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message="Duplicate email in upload file",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
seen_emails.add(normalized_email)
|
||||
|
||||
try:
|
||||
invited_user = await create_invited_user(normalized_email, row_name)
|
||||
except PreconditionFailed as exc:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
masked = mask_email(normalized_email)
|
||||
logger.exception(
|
||||
"Failed to create bulk invite for row %s (%s)",
|
||||
row.row_number,
|
||||
masked,
|
||||
)
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Unexpected error creating invite",
|
||||
)
|
||||
)
|
||||
else:
|
||||
created_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=invited_user,
|
||||
)
|
||||
)
|
||||
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=created_count,
|
||||
skipped_count=skipped_count,
|
||||
error_count=error_count,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
return
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return
|
||||
|
||||
try:
|
||||
r = await get_redis_async()
|
||||
except Exception:
|
||||
r = None
|
||||
|
||||
lock: AsyncClusterLock | None = None
|
||||
|
||||
if r is not None:
|
||||
lock = AsyncClusterLock(
|
||||
redis=r,
|
||||
key=f"tally_seed:{invited_user_id}",
|
||||
owner_id=_WORKER_ID,
|
||||
timeout=_TALLY_STALE_SECONDS,
|
||||
)
|
||||
current_owner = await lock.try_acquire()
|
||||
|
||||
if current_owner is None:
|
||||
logger.warn("Redis unvailable for tally lock - skipping tally enrichement")
|
||||
return
|
||||
elif current_owner != _WORKER_ID:
|
||||
logger.debug(
|
||||
"Tally seed for %s already locked by %s, skipping",
|
||||
invited_user_id,
|
||||
current_owner,
|
||||
)
|
||||
return
|
||||
if (
|
||||
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
|
||||
and invited_user.updatedAt is not None
|
||||
):
|
||||
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
|
||||
if age < _TALLY_STALE_SECONDS:
|
||||
logger.debug(
|
||||
"Tally task for %s still RUNNING (age=%ds), skipping",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Tally task for %s is stale (age=%ds), re-running",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
input_data = await get_business_understanding_input_from_tally(
|
||||
invited_user.email,
|
||||
require_api_key=True,
|
||||
)
|
||||
payload = (
|
||||
SafeJson(input_data.model_dump(exclude_none=True))
|
||||
if input_data is not None
|
||||
else None
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": payload,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
||||
"tallyComputedAt": datetime.now(timezone.utc),
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to compute Tally understanding for invited user %s",
|
||||
invited_user_id,
|
||||
)
|
||||
sanitized_error = re.sub(
|
||||
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
|
||||
)[:_MAX_TALLY_ERROR_LENGTH]
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
||||
"tallyError": sanitized_error,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
||||
existing = _tally_seed_tasks.get(invited_user_id)
|
||||
if existing is not None and not existing.done():
|
||||
logger.debug("Tally task already running for %s, skipping", invited_user_id)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
||||
_tally_seed_tasks[invited_user_id] = task
|
||||
|
||||
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
|
||||
if _tally_seed_tasks.get(_id) is t:
|
||||
del _tally_seed_tasks[_id]
|
||||
|
||||
task.add_done_callback(_on_done)
|
||||
|
||||
|
||||
async def _open_signup_create_user(
|
||||
auth_user_id: str,
|
||||
normalized_email: str,
|
||||
metadata_name: Optional[str],
|
||||
) -> User:
|
||||
"""Create a user without requiring an invite (open signup mode)."""
|
||||
preferred_name = _normalize_name(metadata_name)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
user = await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
await _ensure_default_profile(
|
||||
auth_user_id, normalized_email, preferred_name, tx
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
except UniqueViolationError:
|
||||
existing = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if existing is not None:
|
||||
return User.from_db(existing)
|
||||
raise
|
||||
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
# TODO: We need to change this functions logic before going live
|
||||
async def get_or_activate_user(user_data: dict) -> User:
|
||||
auth_user_id = user_data.get("sub")
|
||||
if not auth_user_id:
|
||||
raise NotAuthorizedError("User ID not found in token")
|
||||
|
||||
auth_email = user_data.get("email")
|
||||
if not auth_email:
|
||||
raise NotAuthorizedError("Email not found in token")
|
||||
|
||||
normalized_email = normalize_email(auth_email)
|
||||
user_metadata = user_data.get("user_metadata")
|
||||
metadata_name = (
|
||||
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
||||
)
|
||||
|
||||
existing_user = None
|
||||
try:
|
||||
existing_user = await get_user_by_id(auth_user_id)
|
||||
except ValueError:
|
||||
existing_user = None
|
||||
except Exception:
|
||||
logger.exception("Error on get user by id during tally enrichment process")
|
||||
raise
|
||||
|
||||
if existing_user is not None:
|
||||
return existing_user
|
||||
|
||||
if not _settings.config.enable_invite_gate or normalized_email.endswith("@agpt.co"):
|
||||
return await _open_signup_create_user(
|
||||
auth_user_id, normalized_email, metadata_name
|
||||
)
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
current_user = await prisma.models.User.prisma(tx).find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if current_user is not None:
|
||||
return User.from_db(current_user)
|
||||
|
||||
current_invited_user = await prisma.models.InvitedUser.prisma(
|
||||
tx
|
||||
).find_unique(where={"email": normalized_email})
|
||||
if current_invited_user is None:
|
||||
raise NotAuthorizedError(
|
||||
"Your email is not allowed to access the platform"
|
||||
)
|
||||
|
||||
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
if current_invited_user.authUserId not in (None, auth_user_id):
|
||||
raise NotAuthorizedError("Your invitation has already been claimed")
|
||||
|
||||
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
||||
await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma(tx).update(
|
||||
where={"id": current_invited_user.id},
|
||||
data={
|
||||
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
||||
"authUserId": auth_user_id,
|
||||
},
|
||||
)
|
||||
|
||||
await _ensure_default_profile(
|
||||
auth_user_id,
|
||||
normalized_email,
|
||||
preferred_name,
|
||||
tx,
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
||||
except UniqueViolationError:
|
||||
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
|
||||
already_created = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if already_created is not None:
|
||||
return User.from_db(already_created)
|
||||
raise RuntimeError(
|
||||
f"UniqueViolationError during activation but user {auth_user_id} not found"
|
||||
)
|
||||
|
||||
get_user_by_id.cache_delete(auth_user_id)
|
||||
get_user_by_email.cache_delete(normalized_email)
|
||||
|
||||
activated_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if activated_user is None:
|
||||
raise RuntimeError(
|
||||
f"Activated user {auth_user_id} was not found after creation"
|
||||
)
|
||||
|
||||
return User.from_db(activated_user)
|
||||
@@ -1,335 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
|
||||
|
||||
from .invited_user import (
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
get_or_activate_user,
|
||||
retry_invited_user_tally,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_db_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=status,
|
||||
authUserId=None,
|
||||
name="Invited User",
|
||||
tallyUnderstanding=tally_understanding,
|
||||
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tallyComputedAt=None,
|
||||
tallyError=None,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
return InvitedUserRecord.from_db(
|
||||
cast(
|
||||
prisma.models.InvitedUser,
|
||||
_invited_user_db_record(
|
||||
status=status,
|
||||
tally_understanding=tally_understanding,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _user_db_record():
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="auth-user-1",
|
||||
email="invited@example.com",
|
||||
emailVerified=True,
|
||||
name="Invited User",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
metadata={},
|
||||
integrations="",
|
||||
stripeCustomerId=None,
|
||||
topUpConfig=None,
|
||||
maxEmailsPerDay=3,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="not-set",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_rejects_existing_active_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await create_invited_user("Invited@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_schedules_tally_seed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await create_invited_user("Invited@example.com", "Invited User")
|
||||
|
||||
assert invited_user.email == "invited@example.com"
|
||||
invited_user_repo.create.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_invited_user_tally_resets_state_and_schedules(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
|
||||
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await retry_invited_user_tally("invite-1")
|
||||
|
||||
assert invited_user.id == "invite-1"
|
||||
invited_user_repo.update.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_requires_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user._settings.config.enable_invite_gate",
|
||||
True,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await get_or_activate_user(
|
||||
{"sub": "auth-user-1", "email": "invited@example.com"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_creates_user_from_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
tx = object()
|
||||
invited_user = _invited_user_db_record(
|
||||
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
|
||||
)
|
||||
created_user = _user_db_record()
|
||||
|
||||
outside_user_repo = Mock()
|
||||
# Only called once at post-transaction verification (line 741);
|
||||
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
|
||||
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
|
||||
|
||||
inside_user_repo = Mock()
|
||||
inside_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
inside_user_repo.create = AsyncMock(return_value=created_user)
|
||||
|
||||
outside_invited_repo = Mock()
|
||||
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
|
||||
inside_invited_repo = Mock()
|
||||
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
inside_invited_repo.update = AsyncMock(return_value=invited_user)
|
||||
|
||||
def user_prisma(client=None):
|
||||
return inside_user_repo if client is tx else outside_user_repo
|
||||
|
||||
def invited_user_prisma(client=None):
|
||||
return inside_invited_repo if client is tx else outside_invited_repo
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_transaction():
|
||||
yield tx
|
||||
|
||||
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
|
||||
# not prisma.models.User.prisma().find_unique which we mock above.
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mock_get_user_by_email = AsyncMock()
|
||||
mock_get_user_by_email.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_email",
|
||||
mock_get_user_by_email,
|
||||
)
|
||||
ensure_profile = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_profile",
|
||||
AsyncMock(),
|
||||
)
|
||||
ensure_onboarding = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_onboarding",
|
||||
AsyncMock(),
|
||||
)
|
||||
apply_tally = mocker.patch(
|
||||
"backend.data.invited_user._apply_tally_understanding",
|
||||
AsyncMock(),
|
||||
)
|
||||
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
side_effect=invited_user_prisma,
|
||||
)
|
||||
|
||||
user = await get_or_activate_user(
|
||||
{
|
||||
"sub": "auth-user-1",
|
||||
"email": "Invited@example.com",
|
||||
"user_metadata": {"name": "Invited User"},
|
||||
}
|
||||
)
|
||||
|
||||
assert user.id == "auth-user-1"
|
||||
inside_user_repo.create.assert_awaited_once()
|
||||
inside_invited_repo.update.assert_awaited_once()
|
||||
ensure_profile.assert_awaited_once()
|
||||
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
|
||||
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_from_text_file(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
_invited_user_record(),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.txt",
|
||||
b"Invited@example.com\nsecond@example.com\n",
|
||||
)
|
||||
|
||||
assert result.created_count == 2
|
||||
assert result.skipped_count == 0
|
||||
assert result.error_count == 0
|
||||
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
PreconditionFailed("An invited user with this email already exists"),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.csv",
|
||||
(
|
||||
"email,name\n"
|
||||
"valid@example.com,Valid User\n"
|
||||
"not-an-email,Bad Row\n"
|
||||
"valid@example.com,Duplicate In File\n"
|
||||
"existing@example.com,Existing User\n"
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
assert result.created_count == 1
|
||||
assert result.skipped_count == 2
|
||||
assert result.error_count == 1
|
||||
assert [row.status for row in result.results] == [
|
||||
"CREATED",
|
||||
"ERROR",
|
||||
"SKIPPED",
|
||||
"SKIPPED",
|
||||
]
|
||||
assert create_invited.await_count == 2
|
||||
@@ -8,8 +8,6 @@ from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.util.settings import Settings
|
||||
|
||||
_settings = Settings()
|
||||
|
||||
|
||||
class NotificationEvent(BaseModel):
|
||||
"""Generic notification event destined for websocket delivery."""
|
||||
@@ -28,7 +26,7 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return _settings.config.notification_event_bus_name
|
||||
return Settings().config.notification_event_bus_name
|
||||
|
||||
async def publish(self, event: NotificationEvent) -> None:
|
||||
await self.publish_event(event, event.user_id)
|
||||
|
||||
@@ -41,7 +41,7 @@ _MAX_PAGES = 100
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
|
||||
def mask_email(email: str) -> str:
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
try:
|
||||
local, domain = email.rsplit("@", 1)
|
||||
@@ -196,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
|
||||
|
||||
Returns (email_index, questions).
|
||||
"""
|
||||
client = _make_tally_client(_settings.secrets.tally_api_key)
|
||||
settings = Settings()
|
||||
client = _make_tally_client(settings.secrets.tally_api_key)
|
||||
|
||||
redis = await get_redis_async()
|
||||
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
|
||||
@@ -331,9 +332,6 @@ Fields:
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
- suggested_prompts (list of 5 strings): short action prompts (each under 20 words) that would help \
|
||||
this person get started with automating their work. Should be specific to their industry, role, and \
|
||||
pain points; actionable and conversational in tone; focused on automation opportunities.
|
||||
|
||||
Form data:
|
||||
"""
|
||||
@@ -341,21 +339,21 @@ Form data:
|
||||
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
async def extract_business_understanding_from_tally(
|
||||
async def extract_business_understanding(
|
||||
formatted_text: str,
|
||||
) -> BusinessUnderstandingInput:
|
||||
"""
|
||||
Use an LLM to extract structured business understanding from form text.
|
||||
"""Use an LLM to extract structured business understanding from form text.
|
||||
|
||||
Raises on timeout or unparseable response so the caller can handle it.
|
||||
"""
|
||||
api_key = _settings.secrets.open_router_api_key
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.open_router_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model=_settings.config.tally_extraction_llm_model,
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
@@ -380,57 +378,9 @@ async def extract_business_understanding_from_tally(
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
# Validate suggested_prompts: filter >20 words, keep top 3
|
||||
raw_prompts = cleaned.get("suggested_prompts", [])
|
||||
if isinstance(raw_prompts, list):
|
||||
valid = [
|
||||
p.strip()
|
||||
for p in raw_prompts
|
||||
if isinstance(p, str) and len(p.strip().split()) <= 20
|
||||
]
|
||||
# This will keep up to 3 suggestions
|
||||
short_prompts = valid[:3] if valid else None
|
||||
if short_prompts:
|
||||
cleaned["suggested_prompts"] = short_prompts
|
||||
else:
|
||||
# We dont want to add a None value suggested_prompts field
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
else:
|
||||
# suggested_prompts must be a list - removing it as its not here
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def get_business_understanding_input_from_tally(
|
||||
email: str,
|
||||
*,
|
||||
require_api_key: bool = False,
|
||||
) -> Optional[BusinessUnderstandingInput]:
|
||||
if not _settings.secrets.tally_api_key:
|
||||
if require_api_key:
|
||||
raise RuntimeError("Tally API key is not configured")
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return None
|
||||
|
||||
masked = mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return None
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return None
|
||||
|
||||
return await extract_business_understanding_from_tally(formatted)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
@@ -445,10 +395,30 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
understanding_input = await get_business_understanding_input_from_tally(email)
|
||||
if understanding_input is None:
|
||||
# Check API key is configured
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
@@ -12,11 +12,11 @@ from backend.data.tally import (
|
||||
_build_email_index,
|
||||
_format_answer,
|
||||
_make_tally_client,
|
||||
_mask_email,
|
||||
_refresh_cache,
|
||||
extract_business_understanding_from_tally,
|
||||
extract_business_understanding,
|
||||
find_submission_by_email,
|
||||
format_submission_for_llm,
|
||||
mask_email,
|
||||
populate_understanding_from_tally,
|
||||
)
|
||||
|
||||
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_input.suggested_prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
) as mock_extract,
|
||||
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
),
|
||||
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
|
||||
|
||||
def test_mask_email():
|
||||
assert mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert mask_email("ab@example.com") == "a***@example.com"
|
||||
assert mask_email("a@example.com") == "a***@example.com"
|
||||
assert _mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert _mask_email("ab@example.com") == "a***@example.com"
|
||||
assert _mask_email("a@example.com") == "a***@example.com"
|
||||
|
||||
|
||||
def test_mask_email_invalid():
|
||||
assert mask_email("no-at-sign") == "***"
|
||||
assert _mask_email("no-at-sign") == "***"
|
||||
|
||||
|
||||
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
|
||||
@@ -394,11 +393,11 @@ def test_extraction_prompt_no_format_placeholders():
|
||||
assert single_braces == [], f"Found format placeholders: {single_braces}"
|
||||
|
||||
|
||||
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
|
||||
# ── extract_business_understanding ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_success():
|
||||
async def test_extract_business_understanding_success():
|
||||
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -407,13 +406,6 @@ async def test_extract_business_understanding_from_tally_success():
|
||||
"business_name": "Acme Corp",
|
||||
"industry": "Technology",
|
||||
"pain_points": ["manual reporting"],
|
||||
"suggested_prompts": [
|
||||
"Automate weekly reports",
|
||||
"Set up invoice processing",
|
||||
"Create a customer onboarding flow",
|
||||
"Track project deadlines automatically",
|
||||
"Send follow-up emails after meetings",
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
@@ -423,56 +415,16 @@ async def test_extract_business_understanding_from_tally_success():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name == "Acme Corp"
|
||||
assert result.industry == "Technology"
|
||||
assert result.pain_points == ["manual reporting"]
|
||||
# suggested_prompts validated and sliced to top 3
|
||||
assert result.suggested_prompts == [
|
||||
"Automate weekly reports",
|
||||
"Set up invoice processing",
|
||||
"Create a customer onboarding flow",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_long_prompts():
|
||||
"""Prompts exceeding 20 words are excluded and only top 3 are kept."""
|
||||
long_prompt = " ".join(["word"] * 21)
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"suggested_prompts": [
|
||||
long_prompt,
|
||||
"Short prompt one",
|
||||
long_prompt,
|
||||
"Short prompt two",
|
||||
"Short prompt three",
|
||||
"Short prompt four",
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.suggested_prompts == [
|
||||
"Short prompt one",
|
||||
"Short prompt two",
|
||||
"Short prompt three",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
async def test_extract_business_understanding_filters_nulls():
|
||||
"""Null values from LLM should be excluded from the result."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -485,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name is None
|
||||
@@ -493,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
async def test_extract_business_understanding_invalid_json():
|
||||
"""Invalid JSON from LLM should raise JSONDecodeError."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "not valid json {"
|
||||
@@ -507,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
|
||||
pytest.raises(json.JSONDecodeError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_timeout():
|
||||
async def test_extract_business_understanding_timeout():
|
||||
"""LLM timeout should propagate as asyncio.TimeoutError."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
@@ -521,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
|
||||
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
# ── _refresh_cache ───────────────────────────────────────────────────────────
|
||||
@@ -540,7 +492,7 @@ async def test_refresh_cache_full_fetch():
|
||||
submissions = SAMPLE_SUBMISSIONS
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
@@ -588,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
|
||||
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
|
||||
@@ -86,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: Optional[list[str]] = pydantic.Field(
|
||||
None, description="LLM-generated suggested prompts based on business context"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
@@ -127,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
@@ -157,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
suggested_prompts=_json_to_list(data.get("suggested_prompts")),
|
||||
)
|
||||
|
||||
|
||||
@@ -175,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
return merged
|
||||
|
||||
|
||||
def merge_business_understanding_data(
|
||||
existing_data: dict[str, Any],
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> dict[str, Any]:
|
||||
merged_data = dict(existing_data)
|
||||
|
||||
merged_business: dict[str, Any] = {}
|
||||
if isinstance(merged_data.get("business"), dict):
|
||||
merged_business = dict(merged_data["business"])
|
||||
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
if input_data.user_name is not None:
|
||||
merged_data["name"] = input_data.user_name
|
||||
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
merged_business[field] = value
|
||||
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(merged_business.get(field))
|
||||
merged_list = _merge_lists(existing_list, value)
|
||||
merged_business[field] = merged_list
|
||||
|
||||
merged_business["version"] = 1
|
||||
merged_data["business"] = merged_business
|
||||
|
||||
# suggested_prompts lives at the top level (not under `business`) because
|
||||
# it's a UI-only artifact consumed by the frontend, not business understanding
|
||||
# data. The `business` sub-dict feeds the system prompt.
|
||||
if input_data.suggested_prompts is not None:
|
||||
merged_data["suggested_prompts"] = input_data.suggested_prompts
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
@@ -310,18 +245,63 @@ async def upsert_business_understanding(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
merged_data = merge_business_understanding_data(existing_data, input_data)
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(merged_data)},
|
||||
"update": {"data": SafeJson(merged_data)},
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""Tests for business understanding merge and format logic."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
format_understanding_for_prompt,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
|
||||
"""Create a BusinessUnderstandingInput with only the specified fields."""
|
||||
return BusinessUnderstandingInput.model_validate(kwargs)
|
||||
|
||||
|
||||
# ─── merge_business_understanding_data: suggested_prompts ─────────────
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_overwrites_existing():
|
||||
"""New suggested_prompts should fully replace existing ones (not append)."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
|
||||
}
|
||||
input_data = _make_input(
|
||||
suggested_prompts=["New prompt A", "New prompt B", "New prompt C"],
|
||||
)
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == [
|
||||
"New prompt A",
|
||||
"New prompt B",
|
||||
"New prompt C",
|
||||
]
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_none_preserves_existing():
|
||||
"""When input has suggested_prompts=None, existing prompts are preserved."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Keep me"],
|
||||
}
|
||||
input_data = _make_input(industry="Finance")
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == ["Keep me"]
|
||||
assert result["business"]["industry"] == "Finance"
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_added_to_empty_data():
|
||||
"""Suggested prompts are set at top level even when starting from empty data."""
|
||||
existing: dict[str, Any] = {}
|
||||
input_data = _make_input(suggested_prompts=["Prompt 1"])
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == ["Prompt 1"]
|
||||
|
||||
|
||||
def test_merge_suggested_prompts_empty_list_overwrites():
|
||||
"""An explicit empty list should overwrite existing prompts."""
|
||||
existing: dict[str, Any] = {
|
||||
"suggested_prompts": ["Old prompt"],
|
||||
"business": {"version": 1},
|
||||
}
|
||||
input_data = _make_input(suggested_prompts=[])
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == []
|
||||
|
||||
|
||||
# ─── format_understanding_for_prompt: excludes suggested_prompts ──────
|
||||
|
||||
|
||||
def test_format_understanding_excludes_suggested_prompts():
|
||||
"""suggested_prompts is UI-only and must NOT appear in the system prompt."""
|
||||
understanding = BusinessUnderstanding(
|
||||
id="test-id",
|
||||
user_id="user-1",
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
updated_at=datetime.now(tz=timezone.utc),
|
||||
user_name="Alice",
|
||||
industry="Technology",
|
||||
suggested_prompts=["Automate reports", "Set up alerts", "Track KPIs"],
|
||||
)
|
||||
|
||||
formatted = format_understanding_for_prompt(understanding)
|
||||
|
||||
assert "Alice" in formatted
|
||||
assert "Technology" in formatted
|
||||
assert "suggested_prompts" not in formatted
|
||||
assert "Automate reports" not in formatted
|
||||
assert "Set up alerts" not in formatted
|
||||
assert "Track KPIs" not in formatted
|
||||
@@ -46,7 +46,7 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
from backend.util.type import convert
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
@@ -213,8 +213,11 @@ def validate_exec(
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(data, schema)
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
|
||||
@@ -70,9 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens
|
||||
tool_call_tokens += _tok_len(item.get("text", ""), enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -148,14 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(mid)
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -403,7 +396,7 @@ def validate_and_remove_orphan_tool_responses(
|
||||
|
||||
if log_warning:
|
||||
logger.warning(
|
||||
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
|
||||
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
|
||||
)
|
||||
|
||||
return _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
@@ -495,9 +488,8 @@ def _ensure_tool_pairs_intact(
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
"Could not find assistant messages for tool_call_ids: %s. "
|
||||
"Removing orphan tool responses.",
|
||||
orphan_tool_call_ids,
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
)
|
||||
recent_messages = _remove_orphan_tool_responses(
|
||||
recent_messages, orphan_tool_call_ids
|
||||
@@ -505,8 +497,8 @@ def _ensure_tool_pairs_intact(
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
|
||||
len(messages_to_prepend),
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
@@ -694,15 +686,11 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
@@ -740,12 +728,6 @@ async def compress_context(
|
||||
# This is more granular than dropping all old messages at once.
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
deletable: list[int] = []
|
||||
# Count assistant messages to ensure we keep at least one
|
||||
assistant_indices: set[int] = {
|
||||
i
|
||||
for i in range(len(msgs))
|
||||
if msgs[i] is not None and msgs[i].get("role") == "assistant"
|
||||
}
|
||||
for i in range(1, len(msgs) - 1):
|
||||
msg = msgs[i]
|
||||
if (
|
||||
@@ -753,9 +735,6 @@ async def compress_context(
|
||||
and not _is_tool_message(msg)
|
||||
and not _is_objective_message(msg)
|
||||
):
|
||||
# Skip if this is the last remaining assistant message
|
||||
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
|
||||
continue
|
||||
deletable.append(i)
|
||||
if not deletable:
|
||||
break
|
||||
|
||||
@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
tally_extraction_llm_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="OpenRouter model ID used for extracting business understanding from Tally form data",
|
||||
)
|
||||
ollama_host: str = Field(
|
||||
default="localhost:11434",
|
||||
description="Default Ollama host; exempted from SSRF checks.",
|
||||
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
enable_invite_gate: bool = Field(
|
||||
default=True,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
default=False,
|
||||
description="If user credit system is enabled or not",
|
||||
|
||||
@@ -249,87 +249,6 @@ def convert(value: Any, target_type: Any) -> Any:
|
||||
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
|
||||
|
||||
|
||||
def _value_satisfies_type(value: Any, target: Any) -> bool:
|
||||
"""Check whether *value* already satisfies *target*, including inner elements.
|
||||
|
||||
For union types this checks each member; for generic container types it
|
||||
recursively checks that inner elements satisfy the type args (e.g. every
|
||||
element in a ``list[str]`` is a ``str``). Returns ``False`` when uncertain
|
||||
so the caller falls through to :func:`convert`.
|
||||
"""
|
||||
# typing.Any cannot be used with isinstance(); treat as always satisfied.
|
||||
if target is Any:
|
||||
return True
|
||||
|
||||
origin = get_origin(target)
|
||||
|
||||
if origin is Union or origin is types.UnionType:
|
||||
non_none = [a for a in get_args(target) if a is not type(None)]
|
||||
return any(_value_satisfies_type(value, member) for member in non_none)
|
||||
|
||||
# Generic container type (e.g. list[str], dict[str, int])
|
||||
if origin is not None:
|
||||
# Guard: origin may not be a runtime type (e.g. Literal)
|
||||
if not isinstance(origin, type):
|
||||
return False
|
||||
if not isinstance(value, origin):
|
||||
return False
|
||||
args = get_args(target)
|
||||
if not args:
|
||||
return True
|
||||
# Check inner elements satisfy the type args
|
||||
if _is_type_or_subclass(origin, list):
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
if _is_type_or_subclass(origin, dict) and len(args) >= 2:
|
||||
return all(
|
||||
_value_satisfies_type(k, args[0]) and _value_satisfies_type(v, args[1])
|
||||
for k, v in value.items()
|
||||
)
|
||||
if (
|
||||
_is_type_or_subclass(origin, set) or _is_type_or_subclass(origin, frozenset)
|
||||
) and args:
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
if _is_type_or_subclass(origin, tuple):
|
||||
# Homogeneous tuple[T, ...] — single type + Ellipsis
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
return all(_value_satisfies_type(v, args[0]) for v in value)
|
||||
# Heterogeneous tuple[T1, T2, ...] — positional types
|
||||
if len(value) != len(args):
|
||||
return False
|
||||
return all(_value_satisfies_type(v, t) for v, t in zip(value, args))
|
||||
# Unhandled generic origin — fall through to convert()
|
||||
return False
|
||||
|
||||
# Simple type (e.g. str, int)
|
||||
if isinstance(target, type):
|
||||
return isinstance(value, target)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def coerce_inputs_to_schema(data: dict[str, Any], schema: type) -> None:
|
||||
"""Coerce *data* values in-place to match *schema*'s field types.
|
||||
|
||||
Uses ``model_fields`` (not ``__annotations__``) so inherited fields are
|
||||
included. Skips coercion when the value already satisfies the target
|
||||
type — in particular for union-typed fields where the value matches one
|
||||
member but differs from the annotation object itself.
|
||||
|
||||
This is the single authoritative coercion step shared by the executor
|
||||
(``validate_exec``) and the CoPilot (``execute_block``).
|
||||
"""
|
||||
for name, field_info in schema.model_fields.items():
|
||||
value = data.get(name)
|
||||
if value is None:
|
||||
continue
|
||||
target = field_info.annotation
|
||||
if target is None:
|
||||
continue
|
||||
if _value_satisfies_type(value, target):
|
||||
continue
|
||||
data[name] = convert(value, target)
|
||||
|
||||
|
||||
class FormattedStringType(str):
|
||||
string_format: str
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Any, List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.type import _value_satisfies_type, coerce_inputs_to_schema, convert
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
def test_type_conversion():
|
||||
@@ -48,343 +46,3 @@ def test_type_conversion():
|
||||
# Test other empty list conversions
|
||||
assert convert([], int) == 0 # len([]) = 0
|
||||
assert convert([], Optional[int]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _value_satisfies_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValueSatisfiesType:
|
||||
# --- simple types ---
|
||||
def test_simple_match(self):
|
||||
assert _value_satisfies_type("hello", str) is True
|
||||
assert _value_satisfies_type(42, int) is True
|
||||
assert _value_satisfies_type(3.14, float) is True
|
||||
assert _value_satisfies_type(True, bool) is True
|
||||
|
||||
def test_simple_mismatch(self):
|
||||
assert _value_satisfies_type("hello", int) is False
|
||||
assert _value_satisfies_type(42, str) is False
|
||||
assert _value_satisfies_type([1, 2], str) is False
|
||||
|
||||
# --- Any ---
|
||||
def test_any_always_satisfied(self):
|
||||
assert _value_satisfies_type("hello", Any) is True
|
||||
assert _value_satisfies_type(42, Any) is True
|
||||
assert _value_satisfies_type([1, 2], Any) is True
|
||||
assert _value_satisfies_type(None, Any) is True
|
||||
|
||||
# --- Optional / Union ---
|
||||
def test_optional_with_value(self):
|
||||
assert _value_satisfies_type("hello", Optional[str]) is True
|
||||
assert _value_satisfies_type(42, Optional[int]) is True
|
||||
|
||||
def test_optional_mismatch(self):
|
||||
assert _value_satisfies_type(42, Optional[str]) is False
|
||||
|
||||
def test_union_matches_first_member(self):
|
||||
assert _value_satisfies_type("hello", str | list[str]) is True
|
||||
|
||||
def test_union_matches_second_member(self):
|
||||
assert _value_satisfies_type(["a", "b"], str | list[str]) is True
|
||||
|
||||
def test_union_no_match(self):
|
||||
assert _value_satisfies_type(42, str | list[str]) is False
|
||||
|
||||
# --- list[T] ---
|
||||
def test_list_str_all_match(self):
|
||||
assert _value_satisfies_type(["a", "b", "c"], list[str]) is True
|
||||
|
||||
def test_list_str_inner_mismatch(self):
|
||||
assert _value_satisfies_type([1, 2, 3], list[str]) is False
|
||||
|
||||
def test_list_int_all_match(self):
|
||||
assert _value_satisfies_type([1, 2, 3], list[int]) is True
|
||||
|
||||
def test_list_int_inner_mismatch(self):
|
||||
assert _value_satisfies_type(["1", "2"], list[int]) is False
|
||||
|
||||
def test_empty_list_satisfies_any_list_type(self):
|
||||
assert _value_satisfies_type([], list[str]) is True
|
||||
assert _value_satisfies_type([], list[int]) is True
|
||||
|
||||
def test_string_does_not_satisfy_list(self):
|
||||
assert _value_satisfies_type("hello", list[str]) is False
|
||||
|
||||
# --- nested list[list[str]] ---
|
||||
def test_nested_list_all_match(self):
|
||||
assert _value_satisfies_type([["a", "b"], ["c"]], list[list[str]]) is True
|
||||
|
||||
def test_nested_list_inner_mismatch(self):
|
||||
assert _value_satisfies_type([["a", 1], ["c"]], list[list[str]]) is False
|
||||
|
||||
def test_nested_list_outer_mismatch(self):
|
||||
assert _value_satisfies_type(["a", "b"], list[list[str]]) is False
|
||||
|
||||
# --- dict[K, V] ---
|
||||
def test_dict_str_int_match(self):
|
||||
assert _value_satisfies_type({"a": 1, "b": 2}, dict[str, int]) is True
|
||||
|
||||
def test_dict_str_int_value_mismatch(self):
|
||||
assert _value_satisfies_type({"a": "1", "b": "2"}, dict[str, int]) is False
|
||||
|
||||
def test_dict_str_int_key_mismatch(self):
|
||||
assert _value_satisfies_type({1: 1, 2: 2}, dict[str, int]) is False
|
||||
|
||||
def test_empty_dict_satisfies(self):
|
||||
assert _value_satisfies_type({}, dict[str, int]) is True
|
||||
|
||||
# --- set[T] / tuple[T] ---
|
||||
def test_set_match(self):
|
||||
assert _value_satisfies_type({1, 2, 3}, set[int]) is True
|
||||
|
||||
def test_set_mismatch(self):
|
||||
assert _value_satisfies_type({"a", "b"}, set[int]) is False
|
||||
|
||||
def test_tuple_homogeneous_match(self):
|
||||
assert _value_satisfies_type((1, 2, 3), tuple[int, ...]) is True
|
||||
|
||||
def test_tuple_homogeneous_mismatch(self):
|
||||
assert _value_satisfies_type((1, "2", 3), tuple[int, ...]) is False
|
||||
|
||||
def test_tuple_heterogeneous_match(self):
|
||||
assert _value_satisfies_type(("a", 1, True), tuple[str, int, bool]) is True
|
||||
|
||||
def test_tuple_heterogeneous_mismatch(self):
|
||||
assert _value_satisfies_type(("a", "b", True), tuple[str, int, bool]) is False
|
||||
|
||||
def test_tuple_heterogeneous_wrong_length(self):
|
||||
assert _value_satisfies_type(("a", 1), tuple[str, int, bool]) is False
|
||||
|
||||
# --- bare generics (no args) ---
|
||||
def test_bare_list(self):
|
||||
assert _value_satisfies_type([1, "a"], list) is True
|
||||
|
||||
def test_bare_dict(self):
|
||||
assert _value_satisfies_type({"a": 1}, dict) is True
|
||||
|
||||
# --- union with generic inner mismatch ---
|
||||
def test_union_list_with_wrong_inner_falls_through(self):
|
||||
# [1, 2] doesn't satisfy list[str] (inner mismatch), and not str either
|
||||
assert _value_satisfies_type([1, 2], str | list[str]) is False
|
||||
|
||||
# --- Literal (non-runtime origin) ---
|
||||
def test_literal_does_not_crash(self):
|
||||
"""Literal origins are not runtime types — should return False, not crash."""
|
||||
assert _value_satisfies_type("active", Literal["active", "inactive"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# coerce_inputs_to_schema — using real Pydantic models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SampleSchema(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
items: list[str]
|
||||
config: dict[str, int] = {}
|
||||
|
||||
|
||||
class NestedSchema(BaseModel):
|
||||
rows: list[list[str]]
|
||||
|
||||
|
||||
class UnionSchema(BaseModel):
|
||||
content: str | list[str]
|
||||
|
||||
|
||||
class OptionalSchema(BaseModel):
|
||||
label: Optional[str] = None
|
||||
value: int = 0
|
||||
|
||||
|
||||
class AnyFieldSchema(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class TestCoerceInputsToSchema:
|
||||
def test_string_to_int(self):
|
||||
data: dict[str, Any] = {"name": "test", "count": "42", "items": ["a"]}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["count"] == 42
|
||||
assert isinstance(data["count"], int)
|
||||
|
||||
def test_json_string_to_list(self):
|
||||
data: dict[str, Any] = {"name": "test", "count": 1, "items": '["a","b","c"]'}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["items"] == ["a", "b", "c"]
|
||||
|
||||
def test_already_correct_types_unchanged(self):
|
||||
data: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"items": ["a", "b"],
|
||||
"config": {"x": 1},
|
||||
}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data == {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"items": ["a", "b"],
|
||||
"config": {"x": 1},
|
||||
}
|
||||
|
||||
def test_inner_element_coercion(self):
|
||||
"""list[str] with int inner elements → coerced to strings."""
|
||||
data: dict[str, Any] = {"name": "test", "count": 1, "items": [1, 2, 3]}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["items"] == ["1", "2", "3"]
|
||||
|
||||
def test_dict_value_coercion(self):
|
||||
"""dict[str, int] with string values → coerced to ints."""
|
||||
data: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"count": 1,
|
||||
"items": [],
|
||||
"config": {"x": "10", "y": "20"},
|
||||
}
|
||||
coerce_inputs_to_schema(data, SampleSchema)
|
||||
assert data["config"] == {"x": 10, "y": 20}
|
||||
|
||||
def test_nested_list_from_json_string(self):
|
||||
data: dict[str, Any] = {
|
||||
"rows": '[["Name","Score"],["Alice","90"]]',
|
||||
}
|
||||
coerce_inputs_to_schema(data, NestedSchema)
|
||||
assert data["rows"] == [["Name", "Score"], ["Alice", "90"]]
|
||||
|
||||
def test_nested_list_already_correct(self):
|
||||
original = [["a", "b"], ["c", "d"]]
|
||||
data: dict[str, Any] = {"rows": original}
|
||||
coerce_inputs_to_schema(data, NestedSchema)
|
||||
assert data["rows"] == original
|
||||
|
||||
def test_union_preserves_valid_list(self):
|
||||
"""list[str] value should NOT be stringified for str | list[str]."""
|
||||
data: dict[str, Any] = {"content": ["a", "b"]}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
assert data["content"] == ["a", "b"]
|
||||
assert isinstance(data["content"], list)
|
||||
|
||||
def test_union_preserves_valid_string(self):
|
||||
data: dict[str, Any] = {"content": "hello"}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
assert data["content"] == "hello"
|
||||
|
||||
def test_union_list_with_wrong_inner_gets_coerced(self):
|
||||
"""[1, 2] for str | list[str] — inner ints don't match list[str],
|
||||
so convert() is called. convert tries str first → stringifies."""
|
||||
data: dict[str, Any] = {"content": [1, 2]}
|
||||
coerce_inputs_to_schema(data, UnionSchema)
|
||||
# convert([1,2], str | list[str]) tries str first → "[1, 2]"
|
||||
# This is convert()'s union behavior — str wins over list[str]
|
||||
assert isinstance(data["content"], (str, list))
|
||||
|
||||
def test_skips_none_values(self):
|
||||
data: dict[str, Any] = {"label": None, "value": "5"}
|
||||
coerce_inputs_to_schema(data, OptionalSchema)
|
||||
assert data["label"] is None
|
||||
assert data["value"] == 5
|
||||
|
||||
def test_skips_missing_fields(self):
|
||||
data: dict[str, Any] = {"value": "10"}
|
||||
coerce_inputs_to_schema(data, OptionalSchema)
|
||||
assert "label" not in data
|
||||
assert data["value"] == 10
|
||||
|
||||
def test_any_field_skipped(self):
|
||||
"""Fields typed as Any should pass through without coercion."""
|
||||
data: dict[str, Any] = {"data": [1, "mixed", {"nested": True}]}
|
||||
coerce_inputs_to_schema(data, AnyFieldSchema)
|
||||
assert data["data"] == [1, "mixed", {"nested": True}]
|
||||
|
||||
def test_preserves_all_convert_capabilities(self):
|
||||
"""Verify coerce_inputs_to_schema doesn't lose any convert() capability
|
||||
that existed before the _value_satisfies_type gate was added."""
|
||||
|
||||
class FullSchema(BaseModel):
|
||||
as_int: int
|
||||
as_float: float
|
||||
as_bool: bool
|
||||
as_str: str
|
||||
as_list: list[int]
|
||||
as_dict: dict[str, str]
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"as_int": "42",
|
||||
"as_float": "3.14",
|
||||
"as_bool": "True",
|
||||
"as_str": 123,
|
||||
"as_list": "[1,2,3]",
|
||||
"as_dict": '{"a": "b"}',
|
||||
}
|
||||
coerce_inputs_to_schema(data, FullSchema)
|
||||
assert data["as_int"] == 42
|
||||
assert data["as_float"] == 3.14
|
||||
assert data["as_bool"] is True
|
||||
assert data["as_str"] == "123"
|
||||
assert data["as_list"] == [1, 2, 3]
|
||||
assert data["as_dict"] == {"a": "b"}
|
||||
|
||||
def test_inherited_fields_are_coerced(self):
|
||||
"""model_fields includes inherited fields; __annotations__ does not.
|
||||
This verifies that fields from a parent schema are still coerced."""
|
||||
|
||||
class ParentSchema(BaseModel):
|
||||
base_count: int
|
||||
|
||||
class ChildSchema(ParentSchema):
|
||||
name: str
|
||||
|
||||
# base_count is inherited — __annotations__ wouldn't include it
|
||||
assert "base_count" not in ChildSchema.__annotations__
|
||||
assert "base_count" in ChildSchema.model_fields
|
||||
|
||||
data: dict[str, Any] = {"base_count": "42", "name": "test"}
|
||||
coerce_inputs_to_schema(data, ChildSchema)
|
||||
assert data["base_count"] == 42
|
||||
assert isinstance(data["base_count"], int)
|
||||
|
||||
def test_nested_pydantic_model_field(self):
|
||||
"""dict input for a Pydantic model-typed field passes through.
|
||||
convert() doesn't construct Pydantic models — Pydantic validation
|
||||
handles that downstream. This test documents the behavior."""
|
||||
|
||||
class InnerModel(BaseModel):
|
||||
x: int
|
||||
|
||||
class OuterModel(BaseModel):
|
||||
inner: InnerModel
|
||||
|
||||
data: dict[str, Any] = {"inner": {"x": 1}}
|
||||
coerce_inputs_to_schema(data, OuterModel)
|
||||
# dict stays as dict — convert() doesn't construct Pydantic models
|
||||
assert data["inner"] == {"x": 1}
|
||||
assert isinstance(data["inner"], dict)
|
||||
|
||||
def test_literal_field_passes_through(self):
|
||||
"""Literal-typed fields should not crash coercion."""
|
||||
|
||||
class LiteralSchema(BaseModel):
|
||||
status: Literal["active", "inactive"]
|
||||
|
||||
data: dict[str, Any] = {"status": "active"}
|
||||
coerce_inputs_to_schema(data, LiteralSchema)
|
||||
assert data["status"] == "active"
|
||||
|
||||
def test_list_of_pydantic_model_field(self):
|
||||
"""list[dict] for list[PydanticModel] passes through unchanged."""
|
||||
|
||||
class ItemModel(BaseModel):
|
||||
name: str
|
||||
|
||||
class ContainerModel(BaseModel):
|
||||
items: list[ItemModel]
|
||||
|
||||
data: dict[str, Any] = {"items": [{"name": "a"}, {"name": "b"}]}
|
||||
coerce_inputs_to_schema(data, ContainerModel)
|
||||
# Dicts stay as dicts — Pydantic validation handles construction
|
||||
assert data["items"] == [{"name": "a"}, {"name": "b"}]
|
||||
assert isinstance(data["items"][0], dict)
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
-- Migrate Gemini 3 Pro Preview to Gemini 3.1 Pro Preview
|
||||
-- This updates all AgentNode blocks that use the deprecated Gemini 3 Pro Preview model
|
||||
-- Google is shutting down google/gemini-3-pro-preview on March 9, 2026
|
||||
|
||||
-- Update AgentNode constant inputs
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-3.1-pro-preview"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'google/gemini-3-pro-preview';
|
||||
|
||||
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
|
||||
UPDATE "AgentNodeExecutionInputOutput"
|
||||
SET "data" = JSONB_SET(
|
||||
"data"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-3.1-pro-preview"'::jsonb
|
||||
)
|
||||
WHERE "agentPresetId" IS NOT NULL
|
||||
AND "data"::jsonb->>'model' = 'google/gemini-3-pro-preview';
|
||||
@@ -1,46 +0,0 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `search` on the `StoreListingVersion` table. All the data in the column will be lost.
|
||||
|
||||
*/-- CreateEnum
|
||||
CREATE TYPE "InvitedUserStatus" AS ENUM('INVITED',
|
||||
'CLAIMED',
|
||||
'REVOKED');
|
||||
-- CreateEnum
|
||||
CREATE TYPE "TallyComputationStatus" AS ENUM('PENDING',
|
||||
'RUNNING',
|
||||
'READY',
|
||||
'FAILED');
|
||||
-- CreateTable
|
||||
CREATE TABLE "InvitedUser"(
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
|
||||
"authUserId" TEXT,
|
||||
"name" TEXT,
|
||||
"tallyUnderstanding" JSONB,
|
||||
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"tallyComputedAt" TIMESTAMP(3),
|
||||
"tallyError" TEXT,
|
||||
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY("id")
|
||||
);
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_email_key"
|
||||
ON "InvitedUser"("email");
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "InvitedUser_authUserId_key"
|
||||
ON "InvitedUser"("authUserId");
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_status_idx"
|
||||
ON "InvitedUser"("status");
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_tallyStatus_idx"
|
||||
ON "InvitedUser"("tallyStatus");
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey" FOREIGN KEY("authUserId") REFERENCES "User"("id")
|
||||
ON DELETE
|
||||
SET NULL
|
||||
ON UPDATE CASCADE;
|
||||
@@ -1,15 +0,0 @@
|
||||
-- Drop the trigger that auto-creates User + Profile on auth.users INSERT.
|
||||
-- The invite activation flow in get_or_activate_user() now handles this.
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = 'auth' AND table_name = 'users'
|
||||
) THEN
|
||||
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
DROP FUNCTION IF EXISTS add_user_and_profile_to_platform();
|
||||
DROP FUNCTION IF EXISTS add_user_to_platform();
|
||||
-- Keep generate_username() — used by backfill migration 20250205110132
|
||||
@@ -1,7 +0,0 @@
|
||||
-- DropIndex
|
||||
DROP INDEX "InvitedUser_status_idx";
|
||||
-- DropIndex
|
||||
DROP INDEX "InvitedUser_tallyStatus_idx";
|
||||
-- CreateIndex
|
||||
CREATE INDEX "InvitedUser_createdAt_idx"
|
||||
ON "InvitedUser"("createdAt");
|
||||
@@ -1,40 +0,0 @@
|
||||
-- Fix PerplexityBlock nodes that have invalid model values (e.g. gpt-4o,
|
||||
-- gpt-5.2-2025-12-11) set by the agent generator. Defaults them to the
|
||||
-- standard "perplexity/sonar" model.
|
||||
--
|
||||
-- PerplexityBlock ID: c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f
|
||||
-- Valid models: perplexity/sonar, perplexity/sonar-pro, perplexity/sonar-deep-research
|
||||
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"perplexity/sonar"'::jsonb
|
||||
)
|
||||
WHERE "agentBlockId" = 'c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f'
|
||||
AND "constantInput"::jsonb ? 'model'
|
||||
AND "constantInput"::jsonb->>'model' NOT IN (
|
||||
'perplexity/sonar',
|
||||
'perplexity/sonar-pro',
|
||||
'perplexity/sonar-deep-research'
|
||||
);
|
||||
|
||||
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput).
|
||||
-- The table links to AgentNode through AgentNodeExecution, not directly.
|
||||
UPDATE "AgentNodeExecutionInputOutput" io
|
||||
SET "data" = JSONB_SET(
|
||||
io."data"::jsonb,
|
||||
'{model}',
|
||||
'"perplexity/sonar"'::jsonb
|
||||
)
|
||||
FROM "AgentNodeExecution" exe
|
||||
JOIN "AgentNode" n ON n."id" = exe."agentNodeId"
|
||||
WHERE io."agentPresetId" IS NOT NULL
|
||||
AND (io."referencedByInputExecId" = exe."id" OR io."referencedByOutputExecId" = exe."id")
|
||||
AND n."agentBlockId" = 'c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f'
|
||||
AND io."data"::jsonb ? 'model'
|
||||
AND io."data"::jsonb->>'model' NOT IN (
|
||||
'perplexity/sonar',
|
||||
'perplexity/sonar-pro',
|
||||
'perplexity/sonar-deep-research'
|
||||
);
|
||||
8
autogpt_platform/backend/poetry.lock
generated
8
autogpt_platform/backend/poetry.lock
generated
@@ -1282,14 +1282,14 @@ pgp = ["gpg"]
|
||||
|
||||
[[package]]
|
||||
name = "e2b"
|
||||
version = "2.15.2"
|
||||
version = "2.15.1"
|
||||
description = "E2B SDK that give agents cloud environments"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "e2b-2.15.2-py3-none-any.whl", hash = "sha256:19a56fbdea25974dc81426ed48337eae6cea91d404f5bcf8861a5a2c6e4d982a"},
|
||||
{file = "e2b-2.15.2.tar.gz", hash = "sha256:414379d2421d6827eeb2eb50a4d6b3fdb7d691b39ff73b5ea05ca4b532819831"},
|
||||
{file = "e2b-2.15.1-py3-none-any.whl", hash = "sha256:a3bc4e004eab51fb05bae44e9ee4fe821e4637260f4ce3064c8f7c6ed7f5a2a0"},
|
||||
{file = "e2b-2.15.1.tar.gz", hash = "sha256:a4f1bbc8b5180a8a1098079257fcb73e42503ed546098f676f722f11f0d68c09"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -8882,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
|
||||
content-hash = "618d61b0586ab82fec1e28d1feb549a198e0b5c9d152e808862e55efc00a65b9"
|
||||
|
||||
@@ -20,7 +20,7 @@ claude-agent-sdk = "0.1.45" # see copilot/sdk/sdk_compat_test.py for capability
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b = "^2.15.2"
|
||||
e2b = "^2.0"
|
||||
e2b-code-interpreter = "^2.0"
|
||||
elevenlabs = "^1.50.0"
|
||||
fastapi = "^0.128.6"
|
||||
|
||||
@@ -65,7 +65,6 @@ model User {
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -74,38 +73,6 @@ model User {
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
}
|
||||
|
||||
enum InvitedUserStatus {
|
||||
INVITED
|
||||
CLAIMED
|
||||
REVOKED
|
||||
}
|
||||
|
||||
enum TallyComputationStatus {
|
||||
PENDING
|
||||
RUNNING
|
||||
READY
|
||||
FAILED
|
||||
}
|
||||
|
||||
model InvitedUser {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
email String @unique
|
||||
status InvitedUserStatus @default(INVITED)
|
||||
authUserId String? @unique
|
||||
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
|
||||
name String?
|
||||
|
||||
tallyUnderstanding Json?
|
||||
tallyStatus TallyComputationStatus @default(PENDING)
|
||||
tallyComputedAt DateTime?
|
||||
tallyError String?
|
||||
|
||||
@@index([createdAt])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
// Introductory onboarding (Library)
|
||||
WELCOME
|
||||
@@ -1025,7 +992,7 @@ model StoreListing {
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentGraphId String @unique
|
||||
agentGraphId String @unique
|
||||
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"id": "test-agent-1",
|
||||
"graph_id": "test-agent-1",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
@@ -50,6 +51,7 @@
|
||||
"id": "test-agent-2",
|
||||
"graph_id": "test-agent-2",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
|
||||
@@ -84,27 +84,6 @@ class TestGmailReadBlock:
|
||||
assert "Hello World" in result
|
||||
assert "This is HTML content" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_html_fallback_when_html2text_conversion_fails(self):
|
||||
"""Fallback to raw HTML when html2text converter raises unexpectedly."""
|
||||
html_text = "<html><body><p>Broken <b>HTML</p></body></html>"
|
||||
|
||||
msg = {
|
||||
"id": "test_msg_html_error",
|
||||
"payload": {
|
||||
"mimeType": "text/html",
|
||||
"body": {"data": self._encode_base64(html_text)},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("html2text.HTML2Text") as mock_html2text:
|
||||
mock_converter = Mock()
|
||||
mock_converter.handle.side_effect = ValueError("conversion failed")
|
||||
mock_html2text.return_value = mock_converter
|
||||
|
||||
result = await self.gmail_block._get_email_body(msg, self.mock_service)
|
||||
assert result == html_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_html_fallback_when_html2text_unavailable(self):
|
||||
"""Test fallback to raw HTML when html2text is not available."""
|
||||
|
||||
@@ -34,7 +34,7 @@ from backend.data.auth.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
faker = Faker()
|
||||
@@ -151,7 +151,7 @@ class TestDataCreator:
|
||||
}
|
||||
|
||||
# Use the API function to create user in local database
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
users.append(user.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import {
|
||||
UsersIcon,
|
||||
CurrencyDollarSimpleIcon,
|
||||
UserPlusIcon,
|
||||
MagnifyingGlassIcon,
|
||||
FileTextIcon,
|
||||
SlidersHorizontalIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
const sidebarLinkGroups = [
|
||||
{
|
||||
@@ -16,32 +9,27 @@ const sidebarLinkGroups = [
|
||||
{
|
||||
text: "Marketplace Management",
|
||||
href: "/admin/marketplace",
|
||||
icon: <UsersIcon size={24} />,
|
||||
icon: <Users className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Spending",
|
||||
href: "/admin/spending",
|
||||
icon: <CurrencyDollarSimpleIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Beta Invites",
|
||||
href: "/admin/users",
|
||||
icon: <UserPlusIcon size={24} />,
|
||||
icon: <DollarSign className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Impersonation",
|
||||
href: "/admin/impersonation",
|
||||
icon: <MagnifyingGlassIcon size={24} />,
|
||||
icon: <UserSearch className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileTextIcon size={24} />,
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
icon: <SlidersHorizontalIcon size={24} />,
|
||||
icon: <IconSliders className="h-6 w-6" />,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Card } from "@/components/atoms/Card/Card";
|
||||
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
|
||||
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
|
||||
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
|
||||
import { useAdminUsersPage } from "../../useAdminUsersPage";
|
||||
|
||||
export function AdminUsersPage() {
|
||||
const {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers,
|
||||
isLoadingInvitedUsers,
|
||||
isRefreshingInvitedUsers,
|
||||
isCreatingInvite,
|
||||
isBulkInviting,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
} = useAdminUsersPage();
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
|
||||
<p className="max-w-3xl text-sm text-zinc-600">
|
||||
Pre-provision beta users before they sign up. Invites store the
|
||||
platform-side record, run Tally understanding extraction, and activate
|
||||
the real account on the user's first authenticated request.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
|
||||
<div className="flex flex-col gap-6">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InviteUserForm
|
||||
email={email}
|
||||
name={name}
|
||||
isSubmitting={isCreatingInvite}
|
||||
onEmailChange={setEmail}
|
||||
onNameChange={setName}
|
||||
onSubmit={handleCreateInvite}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<BulkInviteForm
|
||||
selectedFile={bulkInviteFile}
|
||||
inputKey={bulkInviteInputKey}
|
||||
isSubmitting={isBulkInviting}
|
||||
lastResult={lastBulkInviteResult}
|
||||
onFileChange={handleBulkInviteFileChange}
|
||||
onSubmit={handleBulkInviteSubmit}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InvitedUsersTable
|
||||
invitedUsers={invitedUsers}
|
||||
isLoading={isLoadingInvitedUsers}
|
||||
isRefreshing={isRefreshingInvitedUsers}
|
||||
pendingInviteAction={pendingInviteAction}
|
||||
onRetryTally={handleRetryTally}
|
||||
onRevoke={handleRevoke}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
selectedFile: File | null;
|
||||
inputKey: number;
|
||||
isSubmitting: boolean;
|
||||
lastResult: BulkInvitedUsersResponse | null;
|
||||
onFileChange: (file: File | null) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
|
||||
if (status === "CREATED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "ERROR") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
export function BulkInviteForm({
|
||||
selectedFile,
|
||||
inputKey,
|
||||
isSubmitting,
|
||||
lastResult,
|
||||
onFileChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Upload a <span className="font-medium text-zinc-800">.txt</span> file
|
||||
with one email per line, or a{" "}
|
||||
<span className="font-medium text-zinc-800">.csv</span> with
|
||||
<span className="font-medium text-zinc-800"> email</span> and optional
|
||||
<span className="font-medium text-zinc-800"> name</span> columns.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<label
|
||||
htmlFor="bulk-invite-file-input"
|
||||
className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors focus-within:ring-2 focus-within:ring-zinc-500 focus-within:ring-offset-2 hover:border-zinc-400 hover:bg-zinc-100"
|
||||
>
|
||||
<span className="font-medium text-zinc-900">
|
||||
{selectedFile ? selectedFile.name : "Choose invite file"}
|
||||
</span>
|
||||
<span>Maximum 500 rows, UTF-8 encoded.</span>
|
||||
<input
|
||||
id="bulk-invite-file-input"
|
||||
key={inputKey}
|
||||
type="file"
|
||||
accept=".txt,.csv,text/plain,text/csv"
|
||||
disabled={isSubmitting}
|
||||
className="sr-only"
|
||||
onChange={(event) =>
|
||||
onFileChange(event.target.files?.item(0) ?? null)
|
||||
}
|
||||
/>
|
||||
</label>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!selectedFile}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
|
||||
</Button>
|
||||
|
||||
{lastResult ? (
|
||||
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
|
||||
<div className="grid grid-cols-3 gap-2 text-center">
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.created_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Created
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.skipped_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Skipped
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.error_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Errors
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
|
||||
<div className="flex flex-col divide-y divide-zinc-100">
|
||||
{lastResult.results.map((row) => (
|
||||
<div
|
||||
key={`${row.row_number}-${row.email ?? row.message}`}
|
||||
className="flex items-start gap-3 px-3 py-3"
|
||||
>
|
||||
<Badge variant={getStatusVariant(row.status)} size="small">
|
||||
{row.status}
|
||||
</Badge>
|
||||
<div className="flex min-w-0 flex-1 flex-col gap-1">
|
||||
<span className="text-sm font-medium text-zinc-900">
|
||||
Row {row.row_number}
|
||||
{row.email ? ` · ${row.email}` : ""}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-500">{row.message}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
email: string;
|
||||
name: string;
|
||||
isSubmitting: boolean;
|
||||
onEmailChange: (value: string) => void;
|
||||
onNameChange: (value: string) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
export function InviteUserForm({
|
||||
email,
|
||||
name,
|
||||
isSubmitting,
|
||||
onEmailChange,
|
||||
onNameChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
The invite is stored immediately, then Tally pre-seeding starts in the
|
||||
background.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Input
|
||||
id="invite-email"
|
||||
label="Email"
|
||||
type="email"
|
||||
value={email}
|
||||
placeholder="jane@example.com"
|
||||
autoComplete="email"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onEmailChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Input
|
||||
id="invite-name"
|
||||
label="Name"
|
||||
type="text"
|
||||
value={name}
|
||||
placeholder="Jane Doe"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onNameChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!email.trim()}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Creating invite..." : "Create invite"}
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,209 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
interface Props {
|
||||
invitedUsers: InvitedUserResponse[];
|
||||
isLoading: boolean;
|
||||
isRefreshing: boolean;
|
||||
pendingInviteAction: string | null;
|
||||
onRetryTally: (invitedUserId: string) => void;
|
||||
onRevoke: (invitedUserId: string) => void;
|
||||
}
|
||||
|
||||
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
|
||||
if (status === "CLAIMED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "REVOKED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
|
||||
if (status === "READY") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "FAILED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function formatDate(value: Date | undefined) {
|
||||
if (!value) {
|
||||
return "-";
|
||||
}
|
||||
|
||||
return value.toLocaleString();
|
||||
}
|
||||
|
||||
function getTallySummary(invitedUser: InvitedUserResponse) {
|
||||
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
|
||||
return invitedUser.tally_error;
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
|
||||
return "Stored and ready for activation";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY") {
|
||||
return "No matching Tally submission found";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "RUNNING") {
|
||||
return "Extraction in progress";
|
||||
}
|
||||
|
||||
return "Waiting to run";
|
||||
}
|
||||
|
||||
function isActionPending(
|
||||
pendingInviteAction: string | null,
|
||||
action: "retry" | "revoke",
|
||||
invitedUserId: string,
|
||||
) {
|
||||
return pendingInviteAction === `${action}:${invitedUserId}`;
|
||||
}
|
||||
|
||||
export function InvitedUsersTable({
|
||||
invitedUsers,
|
||||
isLoading,
|
||||
isRefreshing,
|
||||
pendingInviteAction,
|
||||
onRetryTally,
|
||||
onRevoke,
|
||||
}: Props) {
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Live invite state, claim status, and Tally pre-seeding progress.
|
||||
</p>
|
||||
</div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="overflow-hidden rounded-2xl border border-zinc-200">
|
||||
<Table>
|
||||
<TableHeader className="bg-zinc-50">
|
||||
<TableRow>
|
||||
<TableHead>Email</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Invite</TableHead>
|
||||
<TableHead>Tally</TableHead>
|
||||
<TableHead>Claimed User</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead className="text-right">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{isLoading ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
Loading invited users...
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : invitedUsers.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
No invited users yet
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
invitedUsers.map((invitedUser) => (
|
||||
<TableRow key={invitedUser.id} className="align-top">
|
||||
<TableCell className="font-medium text-zinc-900">
|
||||
{invitedUser.email}
|
||||
</TableCell>
|
||||
<TableCell>{invitedUser.name || "-"}</TableCell>
|
||||
<TableCell>
|
||||
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
|
||||
{invitedUser.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex max-w-xs flex-col gap-2">
|
||||
<Badge
|
||||
variant={getTallyBadgeVariant(invitedUser.tally_status)}
|
||||
>
|
||||
{invitedUser.tally_status}
|
||||
</Badge>
|
||||
<span className="text-xs text-zinc-500">
|
||||
{getTallySummary(invitedUser)}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-400">
|
||||
{formatDate(invitedUser.tally_computed_at ?? undefined)}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="font-mono text-xs text-zinc-500">
|
||||
{invitedUser.auth_user_id || "-"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-500">
|
||||
{formatDate(invitedUser.created_at)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={invitedUser.status === "REVOKED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"retry",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRetryTally(invitedUser.id)}
|
||||
>
|
||||
Retry Tally
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
disabled={invitedUser.status !== "INVITED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"revoke",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRevoke(invitedUser.id)}
|
||||
>
|
||||
Revoke
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
|
||||
import React from "react";
|
||||
|
||||
function AdminUsers() {
|
||||
return <AdminUsersPage />;
|
||||
return (
|
||||
<div>
|
||||
<h1>Users Dashboard</h1>
|
||||
{/* Add your admin-only content here */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function AdminUsersRoute() {
|
||||
export default async function AdminUsersPage() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import {
|
||||
getGetV2ListInvitedUsersQueryKey,
|
||||
useGetV2ListInvitedUsers,
|
||||
usePostV2BulkCreateInvitedUsers,
|
||||
usePostV2CreateInvitedUser,
|
||||
usePostV2RetryInvitedUserTally,
|
||||
usePostV2RevokeInvitedUser,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { type FormEvent, useState } from "react";
|
||||
|
||||
function getErrorMessage(error: unknown) {
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return "Something went wrong";
|
||||
}
|
||||
|
||||
export function useAdminUsersPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { toast } = useToast();
|
||||
const [email, setEmail] = useState("");
|
||||
const [name, setName] = useState("");
|
||||
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
|
||||
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
|
||||
const [lastBulkInviteResult, setLastBulkInviteResult] =
|
||||
useState<BulkInvitedUsersResponse | null>(null);
|
||||
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const invitedUsersQuery = useGetV2ListInvitedUsers(undefined, {
|
||||
query: {
|
||||
select: okData,
|
||||
refetchInterval: 30_000,
|
||||
},
|
||||
});
|
||||
|
||||
const createInvitedUserMutation = usePostV2CreateInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setEmail("");
|
||||
setName("");
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invited user created",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
|
||||
mutation: {
|
||||
onSuccess: async (response) => {
|
||||
const result = okData(response) ?? null;
|
||||
setBulkInviteFile(null);
|
||||
setBulkInviteInputKey((currentValue) => currentValue + 1);
|
||||
setLastBulkInviteResult(result);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: result
|
||||
? `${result.created_count} invites created`
|
||||
: "Bulk invite upload complete",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Tally pre-seeding restarted",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invite revoked",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
createInvitedUserMutation.mutate({
|
||||
data: {
|
||||
email,
|
||||
name: name.trim() || null,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRetryTally(invitedUserId: string) {
|
||||
setPendingInviteAction(`retry:${invitedUserId}`);
|
||||
retryInvitedUserTallyMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
function handleBulkInviteFileChange(file: File | null) {
|
||||
setBulkInviteFile(file);
|
||||
}
|
||||
|
||||
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!bulkInviteFile) {
|
||||
return;
|
||||
}
|
||||
|
||||
bulkCreateInvitedUsersMutation.mutate({
|
||||
data: {
|
||||
file: bulkInviteFile,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRevoke(invitedUserId: string) {
|
||||
setPendingInviteAction(`revoke:${invitedUserId}`);
|
||||
revokeInvitedUserMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
return {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
|
||||
invitedUsersError: invitedUsersQuery.error,
|
||||
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
|
||||
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
|
||||
isCreatingInvite: createInvitedUserMutation.isPending,
|
||||
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
};
|
||||
}
|
||||
@@ -75,7 +75,7 @@ export const getSecondCalculatorNode = () => {
|
||||
export const getFormContainerSelector = (blockId: string): string | null => {
|
||||
const node = getNodeByBlockId(blockId);
|
||||
if (node) {
|
||||
return `[data-id="form-creator-container-${node.id}-node"]`;
|
||||
return `[data-id="form-creator-container-${node.id}"]`;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
*
|
||||
* Typography (body, small, action, info, tip, warning) uses Tailwind utilities directly in steps.ts
|
||||
*/
|
||||
import "shepherd.js/dist/css/shepherd.css";
|
||||
import "./tutorial.css";
|
||||
|
||||
export const injectTutorialStyles = () => {
|
||||
|
||||
@@ -1,14 +1,3 @@
|
||||
.new-builder-tutorial-disable {
|
||||
opacity: 0.3 !important;
|
||||
pointer-events: none !important;
|
||||
filter: grayscale(100%) !important;
|
||||
}
|
||||
|
||||
.new-builder-tutorial-highlight {
|
||||
position: relative;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.new-builder-tutorial-highlight * {
|
||||
opacity: 1 !important;
|
||||
filter: none !important;
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { UploadSimple } from "@phosphor-icons/react";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
|
||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||
import { NotificationBanner } from "./components/NotificationBanner/NotificationBanner";
|
||||
import { NotificationDialog } from "./components/NotificationDialog/NotificationDialog";
|
||||
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
@@ -86,6 +90,7 @@ export function CopilotPage() {
|
||||
// Delete functionality
|
||||
sessionToDelete,
|
||||
isDeleting,
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
} = useCopilotPage();
|
||||
@@ -112,7 +117,6 @@ export function CopilotPage() {
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
<NotificationBanner />
|
||||
{/* Drop overlay */}
|
||||
<div
|
||||
className={cn(
|
||||
@@ -141,6 +145,38 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1.5 hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-5 w-5 text-neutral-600" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={() => {
|
||||
const session = sessions.find(
|
||||
(s) => s.id === sessionId,
|
||||
);
|
||||
if (session) {
|
||||
handleDeleteClick(session.id, session.title);
|
||||
}
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -165,7 +201,6 @@ export function CopilotPage() {
|
||||
onCancel={handleCancelDelete}
|
||||
/>
|
||||
)}
|
||||
<NotificationDialog />
|
||||
</SidebarProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { LayoutGroup, motion } from "framer-motion";
|
||||
import { ReactNode } from "react";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { EmptySession } from "../EmptySession/EmptySession";
|
||||
@@ -20,6 +21,7 @@ export interface ChatContainerProps {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -38,6 +40,7 @@ export const ChatContainer = ({
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
@@ -60,6 +63,7 @@ export const ChatContainer = ({
|
||||
status={status}
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
|
||||
@@ -30,6 +30,7 @@ interface Props {
|
||||
status: string;
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
@@ -101,6 +102,7 @@ export function ChatMessagesContainer({
|
||||
status,
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
@@ -133,6 +135,7 @@ export function ChatMessagesContainer({
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
{headerSlot}
|
||||
{isLoading && messages.length === 0 && (
|
||||
<div
|
||||
className="flex flex-1 items-center justify-center"
|
||||
|
||||
@@ -23,37 +23,24 @@ import {
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
DotsThree,
|
||||
PlusCircleIcon,
|
||||
PlusIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { UsageLimits } from "../UsageLimits/UsageLimits";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
const isCollapsed = state === "collapsed";
|
||||
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
completedSessionIDs,
|
||||
clearCompletedSession,
|
||||
} = useCopilotUIStore();
|
||||
const { sessionToDelete, setSessionToDelete } = useCopilotUIStore();
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { data: sessionsResponse, isLoading: isLoadingSessions } =
|
||||
useGetV2ListSessions({ limit: 50 }, { query: { refetchInterval: 10_000 } });
|
||||
useGetV2ListSessions({ limit: 50 });
|
||||
|
||||
const { mutate: deleteSession, isPending: isDeleting } =
|
||||
useDeleteV2DeleteSession({
|
||||
@@ -112,22 +99,6 @@ export function ChatSidebar() {
|
||||
}
|
||||
}, [editingSessionId]);
|
||||
|
||||
// Refetch session list when active session changes
|
||||
useEffect(() => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
}, [sessionId, queryClient]);
|
||||
|
||||
// Clear completed indicator when navigating to a session (works for all paths)
|
||||
useEffect(() => {
|
||||
if (!sessionId || !completedSessionIDs.has(sessionId)) return;
|
||||
clearCompletedSession(sessionId);
|
||||
const remaining = completedSessionIDs.size - 1;
|
||||
document.title =
|
||||
remaining > 0 ? `(${remaining}) Otto is ready - AutoGPT` : "AutoGPT";
|
||||
}, [sessionId, completedSessionIDs, clearCompletedSession]);
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
@@ -257,9 +228,7 @@ export function ChatSidebar() {
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="flex items-center">
|
||||
<UsageLimits />
|
||||
<NotificationToggle />
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
@@ -336,8 +305,8 @@ export function ChatSidebar() {
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
@@ -360,22 +329,10 @@ export function ChatSidebar() {
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
{session.is_processing &&
|
||||
session.id !== sessionId &&
|
||||
!completedSessionIDs.has(session.id) && (
|
||||
<PulseLoader size={16} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== sessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||
import { useCopilotUIStore } from "../../../../store";
|
||||
|
||||
export function NotificationToggle() {
|
||||
const {
|
||||
isNotificationsEnabled,
|
||||
setNotificationsEnabled,
|
||||
isSoundEnabled,
|
||||
toggleSound,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
async function handleToggleNotifications() {
|
||||
if (isNotificationsEnabled) {
|
||||
setNotificationsEnabled(false);
|
||||
return;
|
||||
}
|
||||
if (typeof Notification === "undefined") {
|
||||
toast({
|
||||
title: "Notifications not supported",
|
||||
description: "Your browser does not support notifications.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
const permission = await Notification.requestPermission();
|
||||
if (permission === "granted") {
|
||||
setNotificationsEnabled(true);
|
||||
} else {
|
||||
toast({
|
||||
title: "Notifications blocked",
|
||||
description:
|
||||
"Please allow notifications in your browser settings to enable this feature.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Notification settings">
|
||||
{!isNotificationsEnabled ? (
|
||||
<BellSlash className="!size-5" />
|
||||
) : isSoundEnabled ? (
|
||||
<BellRinging className="!size-5" />
|
||||
) : (
|
||||
<Bell className="!size-5" />
|
||||
)}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-56 p-3">
|
||||
<div className="flex flex-col gap-3">
|
||||
<label className="flex items-center justify-between">
|
||||
<span className="text-sm text-zinc-700">Notifications</span>
|
||||
<Switch
|
||||
checked={isNotificationsEnabled}
|
||||
onCheckedChange={handleToggleNotifications}
|
||||
/>
|
||||
</label>
|
||||
<label className="flex items-center justify-between">
|
||||
<span
|
||||
className={cn(
|
||||
"text-sm text-zinc-700",
|
||||
!isNotificationsEnabled && "opacity-50",
|
||||
)}
|
||||
>
|
||||
Sound
|
||||
</span>
|
||||
<Switch
|
||||
checked={isSoundEnabled && isNotificationsEnabled}
|
||||
onCheckedChange={toggleSound}
|
||||
disabled={!isNotificationsEnabled}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2GetSuggestedPrompts } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { SpinnerGapIcon } from "@phosphor-icons/react";
|
||||
@@ -35,38 +33,15 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
|
||||
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
|
||||
useGetV2GetSuggestedPrompts({
|
||||
query: { staleTime: Infinity },
|
||||
});
|
||||
const customPrompts =
|
||||
suggestedPromptsResponse?.status === 200
|
||||
? suggestedPromptsResponse.data.prompts
|
||||
: undefined;
|
||||
const quickActions = getQuickActions(customPrompts);
|
||||
const quickActions = getQuickActions();
|
||||
const [loadingAction, setLoadingAction] = useState<string | null>(null);
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
);
|
||||
|
||||
// Use matchMedia instead of resize event — fires only when crossing
|
||||
// the 500px and 1081px breakpoints defined in getInputPlaceholder(),
|
||||
// rather than dozens of times per second during a window drag.
|
||||
useEffect(() => {
|
||||
function update() {
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
}
|
||||
const mq500 = window.matchMedia("(min-width: 500px)");
|
||||
const mq1081 = window.matchMedia("(min-width: 1081px)");
|
||||
update();
|
||||
mq500.addEventListener("change", update);
|
||||
mq1081.addEventListener("change", update);
|
||||
return () => {
|
||||
mq500.removeEventListener("change", update);
|
||||
mq1081.removeEventListener("change", update);
|
||||
};
|
||||
}, []);
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
}, [window.innerWidth]);
|
||||
|
||||
async function handleQuickActionClick(action: string) {
|
||||
if (isCreatingSession || loadingAction) return;
|
||||
@@ -116,32 +91,28 @@ export function EmptySession({
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{isLoadingPrompts
|
||||
? Array.from({ length: 3 }, (_, i) => (
|
||||
<Skeleton key={i} className="h-10 w-64 shrink-0 rounded-full" />
|
||||
))
|
||||
: quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => void handleQuickActionClick(action)}
|
||||
disabled={isCreatingSession || loadingAction !== null}
|
||||
aria-busy={loadingAction === action}
|
||||
leftIcon={
|
||||
loadingAction === action ? (
|
||||
<SpinnerGapIcon
|
||||
className="h-4 w-4 animate-spin"
|
||||
weight="bold"
|
||||
/>
|
||||
) : null
|
||||
}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => void handleQuickActionClick(action)}
|
||||
disabled={isCreatingSession || loadingAction !== null}
|
||||
aria-busy={loadingAction === action}
|
||||
leftIcon={
|
||||
loadingAction === action ? (
|
||||
<SpinnerGapIcon
|
||||
className="h-4 w-4 animate-spin"
|
||||
weight="bold"
|
||||
/>
|
||||
) : null
|
||||
}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
|
||||
@@ -12,17 +12,12 @@ export function getInputPlaceholder(width?: number) {
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
const DEFAULT_QUICK_ACTIONS = [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
|
||||
export function getQuickActions(customPrompts?: string[]) {
|
||||
if (customPrompts && customPrompts.length > 0) {
|
||||
return customPrompts;
|
||||
}
|
||||
return DEFAULT_QUICK_ACTIONS;
|
||||
export function getQuickActions() {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null) {
|
||||
|
||||
@@ -3,17 +3,8 @@ import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
PlusIcon,
|
||||
SpeakerHigh,
|
||||
SpeakerSlash,
|
||||
SpinnerGapIcon,
|
||||
X,
|
||||
} from "@phosphor-icons/react";
|
||||
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
|
||||
import { Drawer } from "vaul";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
@@ -61,13 +52,6 @@ export function MobileDrawer({
|
||||
onClose,
|
||||
onOpenChange,
|
||||
}: Props) {
|
||||
const {
|
||||
completedSessionIDs,
|
||||
clearCompletedSession,
|
||||
isSoundEnabled,
|
||||
toggleSound,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
return (
|
||||
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
|
||||
<Drawer.Portal>
|
||||
@@ -78,31 +62,14 @@ export function MobileDrawer({
|
||||
<Drawer.Title className="text-lg font-semibold text-zinc-800">
|
||||
Your chats
|
||||
</Drawer.Title>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
onClick={toggleSound}
|
||||
className="rounded p-1.5 text-zinc-400 transition-colors hover:text-zinc-600"
|
||||
aria-label={
|
||||
isSoundEnabled
|
||||
? "Disable notification sound"
|
||||
: "Enable notification sound"
|
||||
}
|
||||
>
|
||||
{isSoundEnabled ? (
|
||||
<SpeakerHigh className="h-4 w-4" />
|
||||
) : (
|
||||
<SpeakerSlash className="h-4 w-4" />
|
||||
)}
|
||||
</button>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Close sessions"
|
||||
onClick={onClose}
|
||||
>
|
||||
<X width="1rem" height="1rem" />
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Close sessions"
|
||||
onClick={onClose}
|
||||
>
|
||||
<X width="1rem" height="1rem" />
|
||||
</Button>
|
||||
</div>
|
||||
{currentSessionId ? (
|
||||
<div className="mt-2">
|
||||
@@ -136,12 +103,7 @@ export function MobileDrawer({
|
||||
sessions.map((session) => (
|
||||
<button
|
||||
key={session.id}
|
||||
onClick={() => {
|
||||
onSelectSession(session.id);
|
||||
if (completedSessionIDs.has(session.id)) {
|
||||
clearCompletedSession(session.id);
|
||||
}
|
||||
}}
|
||||
onClick={() => onSelectSession(session.id)}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
session.id === currentSessionId
|
||||
@@ -150,7 +112,7 @@ export function MobileDrawer({
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="flex min-w-0 max-w-full items-center gap-1.5">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
@@ -162,18 +124,6 @@ export function MobileDrawer({
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</Text>
|
||||
{session.is_processing &&
|
||||
!completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<PulseLoader size={8} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { BellRinging, X } from "@phosphor-icons/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
|
||||
export function NotificationBanner() {
|
||||
const { setNotificationsEnabled, isNotificationsEnabled } =
|
||||
useCopilotUIStore();
|
||||
|
||||
const [dismissed, setDismissed] = useState(
|
||||
() => storage.get(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED) === "true",
|
||||
);
|
||||
const [permission, setPermission] = useState(() =>
|
||||
typeof Notification !== "undefined" ? Notification.permission : "denied",
|
||||
);
|
||||
|
||||
// Re-read dismissed flag when notifications are toggled off (e.g. clearCopilotLocalData)
|
||||
useEffect(() => {
|
||||
if (!isNotificationsEnabled) {
|
||||
setDismissed(
|
||||
storage.get(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED) === "true",
|
||||
);
|
||||
}
|
||||
}, [isNotificationsEnabled]);
|
||||
|
||||
// Don't show if notifications aren't supported, already decided, dismissed, or already enabled
|
||||
if (
|
||||
typeof Notification === "undefined" ||
|
||||
permission !== "default" ||
|
||||
dismissed ||
|
||||
isNotificationsEnabled
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleEnable() {
|
||||
Notification.requestPermission().then((result) => {
|
||||
setPermission(result);
|
||||
if (result === "granted") {
|
||||
setNotificationsEnabled(true);
|
||||
handleDismiss();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleDismiss() {
|
||||
storage.set(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED, "true");
|
||||
setDismissed(true);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3 border-b border-amber-200 bg-amber-50 px-4 py-2.5">
|
||||
<BellRinging className="h-5 w-5 shrink-0 text-amber-600" weight="fill" />
|
||||
<Text variant="body" className="flex-1 text-sm text-amber-800">
|
||||
Enable browser notifications to know when Otto finishes working, even
|
||||
when you switch tabs.
|
||||
</Text>
|
||||
<Button variant="primary" size="small" onClick={handleEnable}>
|
||||
Enable
|
||||
</Button>
|
||||
<button
|
||||
onClick={handleDismiss}
|
||||
className="rounded p-1 text-amber-400 transition-colors hover:text-amber-600"
|
||||
aria-label="Dismiss"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { BellRinging } from "@phosphor-icons/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
|
||||
export function NotificationDialog() {
|
||||
const {
|
||||
showNotificationDialog,
|
||||
setShowNotificationDialog,
|
||||
setNotificationsEnabled,
|
||||
isNotificationsEnabled,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
const [dismissed, setDismissed] = useState(
|
||||
() => storage.get(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED) === "true",
|
||||
);
|
||||
const [permission, setPermission] = useState(() =>
|
||||
typeof Notification !== "undefined" ? Notification.permission : "denied",
|
||||
);
|
||||
|
||||
// Re-read dismissed flag when notifications are toggled off (e.g. clearCopilotLocalData)
|
||||
useEffect(() => {
|
||||
if (!isNotificationsEnabled) {
|
||||
setDismissed(
|
||||
storage.get(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED) === "true",
|
||||
);
|
||||
}
|
||||
}, [isNotificationsEnabled]);
|
||||
|
||||
const shouldShowAuto =
|
||||
typeof Notification !== "undefined" &&
|
||||
permission === "default" &&
|
||||
!dismissed;
|
||||
|
||||
const isOpen = showNotificationDialog || shouldShowAuto;
|
||||
|
||||
function handleEnable() {
|
||||
if (typeof Notification === "undefined") {
|
||||
handleDismiss();
|
||||
return;
|
||||
}
|
||||
Notification.requestPermission().then((result) => {
|
||||
setPermission(result);
|
||||
if (result === "granted") {
|
||||
setNotificationsEnabled(true);
|
||||
handleDismiss();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleDismiss() {
|
||||
storage.set(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED, "true");
|
||||
setDismissed(true);
|
||||
setShowNotificationDialog(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Stay in the loop"
|
||||
styling={{ maxWidth: "28rem", minWidth: "auto" }}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: async (open) => {
|
||||
if (!open) handleDismiss();
|
||||
},
|
||||
}}
|
||||
onClose={handleDismiss}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col items-center gap-4 py-2">
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-full bg-violet-100">
|
||||
<BellRinging className="h-6 w-6 text-violet-600" weight="fill" />
|
||||
</div>
|
||||
<Text variant="body" className="text-center text-neutral-600">
|
||||
Otto can notify you when a response is ready, even if you switch
|
||||
tabs or close this page. Enable notifications so you never miss one.
|
||||
</Text>
|
||||
</div>
|
||||
<Dialog.Footer>
|
||||
<Button variant="secondary" onClick={handleDismiss}>
|
||||
Not now
|
||||
</Button>
|
||||
<Button variant="primary" onClick={handleEnable}>
|
||||
Enable notifications
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -15,8 +15,6 @@
|
||||
position: absolute;
|
||||
left: 0;
|
||||
top: 0;
|
||||
transform: scale(0);
|
||||
opacity: 0;
|
||||
animation: ripple 2s linear infinite;
|
||||
}
|
||||
|
||||
@@ -27,10 +25,7 @@
|
||||
@keyframes ripple {
|
||||
0% {
|
||||
transform: scale(0);
|
||||
opacity: 0.6;
|
||||
}
|
||||
50% {
|
||||
opacity: 0.3;
|
||||
opacity: 1;
|
||||
}
|
||||
100% {
|
||||
transform: scale(1);
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChartBar } from "@phosphor-icons/react";
|
||||
import { useUsageLimits } from "./useUsageLimits";
|
||||
|
||||
const MS_PER_MINUTE = 60_000;
|
||||
const MS_PER_HOUR = 3_600_000;
|
||||
|
||||
function formatResetTime(resetsAt: Date | string): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const now = new Date();
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / MS_PER_HOUR);
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
used,
|
||||
limit,
|
||||
resetsAt,
|
||||
}: {
|
||||
label: string;
|
||||
used: number;
|
||||
limit: number;
|
||||
resetsAt: Date | string;
|
||||
}) {
|
||||
if (limit <= 0) return null;
|
||||
|
||||
const rawPercent = (used / limit) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<span className="text-xs font-medium text-neutral-700">{label}</span>
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-neutral-400">
|
||||
Resets {formatResetTime(resetsAt)}
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsagePanelContent({
|
||||
usage,
|
||||
showBillingLink = true,
|
||||
}: {
|
||||
usage: CoPilotUsageStatus;
|
||||
showBillingLink?: boolean;
|
||||
}) {
|
||||
const hasDailyLimit = usage.daily.limit > 0;
|
||||
const hasWeeklyLimit = usage.weekly.limit > 0;
|
||||
|
||||
if (!hasDailyLimit && !hasWeeklyLimit) {
|
||||
return (
|
||||
<div className="text-xs text-neutral-500">No usage limits configured</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
|
||||
{hasDailyLimit && (
|
||||
<UsageBar
|
||||
label="Today"
|
||||
used={usage.daily.used}
|
||||
limit={usage.daily.limit}
|
||||
resetsAt={usage.daily.resets_at}
|
||||
/>
|
||||
)}
|
||||
{hasWeeklyLimit && (
|
||||
<UsageBar
|
||||
label="This week"
|
||||
used={usage.weekly.used}
|
||||
limit={usage.weekly.limit}
|
||||
resetsAt={usage.weekly.resets_at}
|
||||
/>
|
||||
)}
|
||||
{showBillingLink && (
|
||||
<a
|
||||
href="/profile/credits"
|
||||
className="text-[11px] text-blue-600 hover:underline"
|
||||
>
|
||||
Learn more about usage limits
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsageLimits() {
|
||||
const { data: usage, isLoading } = useUsageLimits();
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Usage limits">
|
||||
<ChartBar className="!size-5" weight="light" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-64 p-3">
|
||||
<UsagePanelContent usage={usage} />
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsageLimits } from "../UsageLimits";
|
||||
|
||||
// Mock the useUsageLimits hook
|
||||
const mockUseUsageLimits = vi.fn();
|
||||
vi.mock("../useUsageLimits", () => ({
|
||||
useUsageLimits: () => mockUseUsageLimits(),
|
||||
}));
|
||||
|
||||
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
|
||||
vi.mock("@/components/molecules/Popover/Popover", () => ({
|
||||
Popover: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseUsageLimits.mockReset();
|
||||
});
|
||||
|
||||
function makeUsage({
|
||||
dailyUsed = 500,
|
||||
dailyLimit = 10000,
|
||||
weeklyUsed = 2000,
|
||||
weeklyLimit = 50000,
|
||||
}: {
|
||||
dailyUsed?: number;
|
||||
dailyLimit?: number;
|
||||
weeklyUsed?: number;
|
||||
weeklyLimit?: number;
|
||||
} = {}) {
|
||||
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
|
||||
return {
|
||||
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
|
||||
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
|
||||
};
|
||||
}
|
||||
|
||||
describe("UsageLimits", () => {
|
||||
it("renders nothing while loading", () => {
|
||||
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders nothing when no limits are configured", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
|
||||
isLoading: false,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders the usage button when limits exist", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays daily and weekly usage percentages", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("50% used")).toBeDefined();
|
||||
expect(screen.getByText("Today")).toBeDefined();
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows only weekly bar when daily limit is 0", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({
|
||||
dailyLimit: 0,
|
||||
weeklyUsed: 25000,
|
||||
weeklyLimit: 50000,
|
||||
}),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.queryByText("Today")).toBeNull();
|
||||
});
|
||||
|
||||
it("caps percentage at 100% when over limit", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows learn more link to credits page", () => {
|
||||
mockUseUsageLimits.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
const link = screen.getByText("Learn more about usage limits");
|
||||
expect(link).toBeDefined();
|
||||
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
|
||||
});
|
||||
});
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user