Compare commits

..

12 Commits

Author SHA1 Message Date
Otto (AGPT)
e9afd9fa01 fix: cast OnboardingStep enum to text in funnel view
The completedSteps column is a platform."OnboardingStep" enum array.
UNNEST produces enum values that can't be compared directly to text
from the VALUES clause. Adding ::text cast fixes the type mismatch.
2026-03-13 11:55:37 +00:00
Zamil Majdy
ddb4f6e9de fix(analytics): address second batch of PR review comments
- user_onboarding_funnel: build complete 22-step grid with VALUES CTE
  so zero-completion steps are always present, fixing LAG comparisons
  against wrong predecessors; update docs to reflect all 22 steps
- users_activities: use COUNT(DISTINCT "id") for agent_count to avoid
  counting multiple version rows per graph; add COALESCE(..., 0) for
  agent_count, unique_agent_runs, agent_runs; update docs column list
  to include node_execution_incomplete and node_execution_review
- generate_views: update Step 3 comment to clarify NOLOGIN role needs
  WITH LOGIN PASSWORD not just WITH PASSWORD; add fail-fast validation
  for unknown --only view names with helpful error message
2026-03-12 00:47:55 +07:00
Zamil Majdy
f585d97928 fix(analytics): move new status columns to end of users_activities SELECT
CREATE OR REPLACE VIEW requires existing columns to stay in position.
Moving node_execution_incomplete and node_execution_review after
is_active_after_7d so the replacement doesn't shift existing columns.
2026-03-12 00:01:40 +07:00
Zamil Majdy
7d39234fdd fix(analytics): address PR review comments
- user_block_spending: use ->> instead of -> for JSONB field extraction
  before casting to int (avoids runtime cast errors)
- generate_views: create analytics_readonly as NOLOGIN to avoid a
  usable role with a known default password
- generate_views: percent-encode DB credentials in the URI builder so
  passwords with reserved chars (@, :, /) connect correctly
- graph_execution: remove WHERE filter on sensitive_action_safe_mode
  before DISTINCT ON so the latest LibraryAgent version always wins
  (fixes possibly_ai being sticky once any version had the flag set)
- retention_agent: use DISTINCT ON ordered by version DESC instead of
  MAX(name) so renamed agents resolve to their latest name
- retention_login_daily: add 90-day cohort_start filter to first_login
  CTE so the view matches its documented window
- user_onboarding_funnel: map the 8 missing OnboardingStep enum values
  (VISIT_COPILOT, RE_RUN_AGENT, SCHEDULE_AGENT, RUN_AGENTS, RUN_3_DAYS,
  TRIGGER_WEBHOOK, RUN_14_DAYS, RUN_AGENTS_100) to step_order 15-22
- users_activities: use updatedAt instead of createdAt for
  last_agent_save_time; add node_execution_incomplete and
  node_execution_review status columns
2026-03-11 23:48:42 +07:00
Zamil Majdy
6e9d4c4333 perf(analytics): fix fan-out in users_activities view
The original CTEs drove all joins from user_logins, causing a
O(users × executions × node_executions) fan-out that made the view
too heavy for Supabase to serve. Rewrote each CTE to aggregate its
own source table directly by userId, then LEFT JOIN the aggregates
in the final SELECT.
2026-03-11 23:39:14 +07:00
Zamil Majdy
8aad333a45 refactor(analytics): move generate_views.py to backend, add poetry run analytics-setup/analytics-views scripts 2026-03-11 16:23:29 +07:00
Zamil Majdy
856f0d980d fix(analytics): restrict analytics_readonly to analytics schema only via security_invoker=false views 2026-03-11 16:16:03 +07:00
Zamil Majdy
3c3aadd361 docs(analytics): add step-by-step quick start to generate_views.py docstring 2026-03-11 16:12:22 +07:00
Zamil Majdy
e87a693fdd feat(analytics): auto-load DB creds from backend/.env as fallback 2026-03-11 16:10:31 +07:00
Zamil Majdy
fe265c10d4 refactor(analytics): generate setup.sql via --setup flag, gitignore it 2026-03-11 16:01:52 +07:00
Zamil Majdy
5d00a94693 chore(analytics): remove auto-generated files, gitignore views.sql 2026-03-11 16:00:48 +07:00
Zamil Majdy
6e1605994d feat(analytics): add documented SQL views with generation script
Introduces an analytics/ layer that wraps production Postgres data in
safe, read-only views exposed under the analytics schema.

- 14 documented query files in queries/ (one per Looker data source)
  covering auth activities, user activity, execution metrics, onboarding
  funnel, and cohort retention (login + execution, weekly + daily)
- setup.sql — one-time schema creation and role/grant setup for the
  analytics_readonly role (auth, platform, analytics schemas)
- generate_views.py — reads queries/*.sql and applies
  CREATE OR REPLACE VIEW analytics.<name> to the database;
  supports --dry-run, --only, and --db-url flags
- views.sql — pre-generated combined reference output
- README.md — full setup, deployment, and integration guide

Looker, PostHog Data Warehouse, and Supabase MCP (for Otto) all
connect to the same analytics.* views instead of raw tables.
2026-03-11 15:36:27 +07:00
145 changed files with 1413 additions and 11794 deletions

View File

@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
## ===== SIGNUP / INVITE GATE ===== ##
# Set to true to require an invite before users can sign up
ENABLE_INVITE_GATE=false
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000

View File

@@ -1,17 +1,8 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.util.models import Pagination
if TYPE_CHECKING:
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class CreateInvitedUserRequest(BaseModel):
email: EmailStr
name: Optional[str] = None
class InvitedUserResponse(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
return cls.model_validate(record.model_dump())
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
pagination: Pagination
class BulkInvitedUserRowResponse(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]
@classmethod
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return cls(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
InvitedUserResponse.from_record(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)

View File

@@ -1,137 +0,0 @@
import logging
import math
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Query, Security, UploadFile
from backend.data.invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from backend.data.tally import mask_email
from backend.util.models import Pagination
from .model import (
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["users", "admin"],
dependencies=[Security(requires_admin_user)],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
summary="List Invited Users",
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users, total = await list_invited_users(page=page, page_size=page_size)
return InvitedUsersResponse(
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
pagination=Pagination(
total_items=total,
total_pages=max(1, math.ceil(total / page_size)),
current_page=page,
page_size=page_size,
),
)
@router.post(
"/invited-users",
response_model=InvitedUserResponse,
summary="Create Invited User",
)
async def create_invited_user_route(
request: CreateInvitedUserRequest,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s creating invited user for %s",
admin_user_id,
mask_email(request.email),
)
invited_user = await create_invited_user(request.email, request.name)
logger.info(
"Admin user %s created invited user %s",
admin_user_id,
invited_user.id,
)
return InvitedUserResponse.from_record(invited_user)
@router.post(
"/invited-users/bulk",
response_model=BulkInvitedUsersResponse,
summary="Bulk Create Invited Users",
operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
file: UploadFile = File(...),
admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
logger.info(
"Admin user %s bulk invited users from %s",
admin_user_id,
file.filename or "<unnamed>",
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return BulkInvitedUsersResponse.from_result(result)
@router.post(
"/invited-users/{invited_user_id}/revoke",
response_model=InvitedUserResponse,
summary="Revoke Invited User",
)
async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
)
invited_user = await revoke_invited_user(invited_user_id)
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
return InvitedUserResponse.from_record(invited_user)
@router.post(
"/invited-users/{invited_user_id}/retry-tally",
response_model=InvitedUserResponse,
summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retrying Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
return InvitedUserResponse.from_record(invited_user)

View File

@@ -1,168 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
now = datetime.now(timezone.utc)
return InvitedUserRecord(
id="invite-1",
email="invited@example.com",
status=prisma.enums.InvitedUserStatus.INVITED,
auth_user_id=None,
name="Invited User",
tally_understanding=None,
tally_status=prisma.enums.TallyComputationStatus.PENDING,
tally_computed_at=None,
tally_error=None,
created_at=now,
updated_at=now,
)
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
return BulkInvitedUsersResult(
created_count=1,
skipped_count=1,
error_count=0,
results=[
BulkInvitedUserRowResult(
row_number=1,
email="invited@example.com",
name=None,
status="CREATED",
message="Invite created",
invited_user=_sample_invited_user(),
),
BulkInvitedUserRowResult(
row_number=2,
email="duplicate@example.com",
name=None,
status="SKIPPED",
message="An invited user with this email already exists",
invited_user=None,
),
],
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=([_sample_invited_user()], 1)),
)
response = client.get("/admin/invited-users")
assert response.status_code == 200
data = response.json()
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
assert data["pagination"]["total_items"] == 1
assert data["pagination"]["current_page"] == 1
assert data["pagination"]["page_size"] == 50
def test_create_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.create_invited_user",
AsyncMock(return_value=_sample_invited_user()),
)
response = client.post(
"/admin/invited-users",
json={"email": "invited@example.com", "name": "Invited User"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == "invited@example.com"
assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
AsyncMock(return_value=_sample_bulk_invited_users_result()),
)
response = client.post(
"/admin/invited-users/bulk",
files={
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
},
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert data["skipped_count"] == 1
assert data["results"][0]["status"] == "CREATED"
assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
mocker: pytest_mock.MockerFixture,
) -> None:
revoked = _sample_invited_user().model_copy(
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
AsyncMock(return_value=revoked),
)
response = client.post("/admin/invited-users/invite-1/revoke")
assert response.status_code == 200
assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
mocker: pytest_mock.MockerFixture,
) -> None:
retried = _sample_invited_user().model_copy(
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
)
mocker.patch(
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
AsyncMock(return_value=retried),
)
response = client.post("/admin/invited-users/invite-1/retry-tally")
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"

View File

@@ -27,12 +27,6 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -59,8 +53,6 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
@@ -126,8 +118,6 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -137,7 +127,6 @@ class SessionSummaryResponse(BaseModel):
created_at: str
updated_at: str
title: str | None = None
is_processing: bool
class ListSessionsResponse(BaseModel):
@@ -196,28 +185,6 @@ async def list_sessions(
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
# Batch-check Redis for active stream status on each session
processing_set: set[str] = set()
if sessions:
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
for session in sessions:
pipe.hget(
f"{config.session_meta_prefix}{session.session_id}",
"status",
)
statuses = await pipe.execute()
processing_set = {
session.session_id
for session, st in zip(sessions, statuses)
if st == "running"
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; " "defaulting to empty"
)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
@@ -225,7 +192,6 @@ async def list_sessions(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
is_processing=session.session_id in processing_set,
)
for session in sessions
],
@@ -397,10 +363,6 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -408,26 +370,6 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get("/usage")
async def get_copilot_usage(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
if not user_id:
raise HTTPException(status_code=401, detail="Authentication required")
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -528,17 +470,6 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based)
if user_id and (config.daily_token_limit > 0 or config.weekly_token_limit > 0):
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
@@ -897,36 +828,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts."""
prompts: list[str]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts for the authenticated user.
Returns personalized quick-action prompts based on the user's
business understanding. Returns an empty list if no custom prompts
are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None:
return SuggestedPromptsResponse(prompts=[])
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
# ========== Configuration ==========

View File

@@ -1,7 +1,6 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
"""Tests for chat API routes: session title update and file attachment validation."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -250,130 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding and prompts gets them back."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but no prompts gets empty list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = []
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}

View File

@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
image_url: str | None
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool = pydantic.Field(
description="Indicates whether the same user owns the corresponding graph"
)
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -55,7 +55,6 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import get_or_activate_user
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
update_user_onboarding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -136,10 +136,12 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -163,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
) -> dict[str, str]:
await update_user_email(user_id, email)
@@ -178,16 +179,10 @@ async def update_user_email_route(
dependencies=[Security(requires_user)],
)
async def get_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
try:
user = await get_user_by_id(user_id)
except ValueError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail="User not found. Please complete activation via /auth/user first.",
)
user = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -198,8 +193,7 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_activate_user",
"backend.api.features.v1.get_or_create_user",
return_value=mock_user,
)

View File

@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
class OnboardingNotificationPayload(NotificationPayload):
step: OnboardingStep | None
class CopilotCompletionPayload(NotificationPayload):
session_id: str
status: Literal["completed", "failed"]

View File

@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -312,11 +311,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Email sent successfully")],
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -1,3 +0,0 @@
def github_repo_path(repo_url: str) -> str:
"""Extract 'owner/repo' from a GitHub repository URL."""
return repo_url.replace("https://github.com/", "")

View File

@@ -1,374 +0,0 @@
import asyncio
from enum import StrEnum
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListCommitsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to list commits from",
default="main",
)
per_page: int = SchemaField(
description="Number of commits to return (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class CommitItem(TypedDict):
sha: str
message: str
author: str
date: str
url: str
commit: CommitItem = SchemaField(
title="Commit", description="A commit with its details"
)
commits: list[CommitItem] = SchemaField(
description="List of commits with their details"
)
error: str = SchemaField(description="Error message if listing commits failed")
def __init__(self):
super().__init__(
id="8b13f579-d8b6-4dc2-a140-f770428805de",
description="This block lists commits on a branch in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommitsBlock.Input,
output_schema=GithubListCommitsBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"commits",
[
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
],
),
(
"commit",
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
},
),
],
test_mock={
"list_commits": lambda *args, **kwargs: [
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
]
},
)
@staticmethod
async def list_commits(
credentials: GithubCredentials,
repo_url: str,
branch: str,
per_page: int,
page: int,
) -> list[Output.CommitItem]:
api = get_api(credentials)
commits_url = repo_url + "/commits"
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
response = await api.get(commits_url, params=params)
data = response.json()
repo_path = github_repo_path(repo_url)
return [
GithubListCommitsBlock.Output.CommitItem(
sha=c["sha"],
message=c["commit"]["message"],
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
date=(c["commit"].get("author") or {}).get("date", ""),
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
)
for c in data
]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
commits = await self.list_commits(
credentials,
input_data.repo_url,
input_data.branch,
input_data.per_page,
input_data.page,
)
yield "commits", commits
for commit in commits:
yield "commit", commit
except Exception as e:
yield "error", str(e)
class FileOperation(StrEnum):
"""File operations for GithubMultiFileCommitBlock.
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
between creation and update — the blob is placed at the given path regardless
of whether a file already exists there).
DELETE removes a file from the tree.
"""
UPSERT = "upsert"
DELETE = "delete"
class FileOperationInput(TypedDict):
path: str
content: str
operation: FileOperation
class GithubMultiFileCommitBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch to commit to",
placeholder="feature-branch",
)
commit_message: str = SchemaField(
description="Commit message",
placeholder="Add new feature",
)
files: list[FileOperationInput] = SchemaField(
description=(
"List of file operations. Each item has: "
"'path' (file path), 'content' (file content, ignored for delete), "
"'operation' (upsert/delete)"
),
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the new commit")
url: str = SchemaField(description="URL of the new commit")
error: str = SchemaField(description="Error message if the commit failed")
def __init__(self):
super().__init__(
id="389eee51-a95e-4230-9bed-92167a327802",
description=(
"This block creates a single commit with multiple file "
"upsert/delete operations using the Git Trees API."
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMultiFileCommitBlock.Input,
output_schema=GithubMultiFileCommitBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "Add files",
"files": [
{
"path": "src/new.py",
"content": "print('hello')",
"operation": "upsert",
},
{
"path": "src/old.py",
"content": "",
"operation": "delete",
},
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "newcommitsha"),
("url", "https://github.com/owner/repo/commit/newcommitsha"),
],
test_mock={
"multi_file_commit": lambda *args, **kwargs: (
"newcommitsha",
"https://github.com/owner/repo/commit/newcommitsha",
)
},
)
@staticmethod
async def multi_file_commit(
credentials: GithubCredentials,
repo_url: str,
branch: str,
commit_message: str,
files: list[FileOperationInput],
) -> tuple[str, str]:
api = get_api(credentials)
safe_branch = quote(branch, safe="")
# 1. Get the latest commit SHA for the branch
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
response = await api.get(ref_url)
ref_data = response.json()
latest_commit_sha = ref_data["object"]["sha"]
# 2. Get the tree SHA of the latest commit
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
response = await api.get(commit_url)
commit_data = response.json()
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str) -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": "utf-8"},
)
return blob_response.json()["sha"]
tree_entries: list[dict] = []
upsert_files = []
for file_op in files:
path = file_op["path"]
operation = FileOperation(file_op.get("operation", "upsert"))
if operation == FileOperation.DELETE:
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": None, # null SHA = delete
}
)
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently
if upsert_files:
blob_shas = await asyncio.gather(
*[_create_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": blob_sha,
}
)
# 4. Create a new tree
tree_url = repo_url + "/git/trees"
tree_response = await api.post(
tree_url,
json={"base_tree": base_tree_sha, "tree": tree_entries},
)
new_tree_sha = tree_response.json()["sha"]
# 5. Create a new commit
new_commit_url = repo_url + "/git/commits"
commit_response = await api.post(
new_commit_url,
json={
"message": commit_message,
"tree": new_tree_sha,
"parents": [latest_commit_sha],
},
)
new_commit_sha = commit_response.json()["sha"]
# 6. Update the branch reference
try:
await api.patch(
ref_url,
json={"sha": new_commit_sha},
)
except Exception as e:
raise RuntimeError(
f"Commit {new_commit_sha} was created but failed to update "
f"ref heads/{branch}: {e}. "
f"You can recover by manually updating the branch to {new_commit_sha}."
) from e
repo_path = github_repo_path(repo_url)
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
return new_commit_sha, commit_web_url
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
input_data.files,
)
yield "sha", sha
yield "url", url
except Exception as e:
yield "error", str(e)

View File

@@ -1,5 +1,4 @@
import re
from typing import Literal
from typing_extensions import TypedDict
@@ -21,8 +20,6 @@ from ._auth import (
GithubCredentialsInput,
)
MergeMethod = Literal["merge", "squash", "rebase"]
class GithubListPullRequestsBlock(Block):
class Input(BlockSchemaInput):
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
yield "reviewer", reviewer
class GithubMergePullRequestBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
pr_url: str = SchemaField(
description="URL of the GitHub pull request",
placeholder="https://github.com/owner/repo/pull/1",
)
merge_method: MergeMethod = SchemaField(
description="Merge method to use: merge, squash, or rebase",
default="merge",
)
commit_title: str = SchemaField(
description="Title for the merge commit (optional, used for merge and squash)",
default="",
)
commit_message: str = SchemaField(
description="Message for the merge commit (optional, used for merge and squash)",
default="",
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the merge commit")
merged: bool = SchemaField(description="Whether the PR was merged")
message: str = SchemaField(description="Merge status message")
error: str = SchemaField(description="Error message if the merge failed")
def __init__(self):
super().__init__(
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
description="This block merges a pull request using merge, squash, or rebase.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMergePullRequestBlock.Input,
output_schema=GithubMergePullRequestBlock.Output,
test_input={
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "abc123"),
("merged", True),
("message", "Pull Request successfully merged"),
],
test_mock={
"merge_pr": lambda *args, **kwargs: (
"abc123",
True,
"Pull Request successfully merged",
)
},
is_sensitive_action=True,
)
@staticmethod
async def merge_pr(
credentials: GithubCredentials,
pr_url: str,
merge_method: MergeMethod,
commit_title: str,
commit_message: str,
) -> tuple[str, bool, str]:
api = get_api(credentials)
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
data: dict[str, str] = {"merge_method": merge_method}
if commit_title:
data["commit_title"] = commit_title
if commit_message:
data["commit_message"] = commit_message
response = await api.put(merge_url, json=data)
result = response.json()
return result["sha"], result["merged"], result["message"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, merged, message = await self.merge_pr(
credentials,
input_data.pr_url,
input_data.merge_method,
input_data.commit_title,
input_data.commit_message,
)
yield "sha", sha
yield "merged", merged
yield "message", message
except Exception as e:
yield "error", str(e)
def prepare_pr_api_url(pr_url: str, path: str) -> str:
# Pattern to capture the base repository URL and the pull request number
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
match = re.match(pattern, pr_url)
if not match:
return pr_url
scheme, base_url, pr_number = match.groups()
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
base_url, pr_number = match.groups()
return f"{base_url}/pulls/{pr_number}/{path}"

View File

@@ -1,3 +1,5 @@
import base64
from typing_extensions import TypedDict
from backend.blocks._base import (
@@ -17,7 +19,6 @@ from ._auth import (
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListTagsBlock(Block):
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
tags_url = repo_url + "/tags"
response = await api.get(tags_url)
data = response.json()
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
tags: list[GithubListTagsBlock.Output.TagItem] = [
{
"name": tag["name"],
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
yield "tag", tag
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(branches_url)
data = response.json()
repo_path = repo_url.replace("https://github.com/", "")
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
branches = await self.list_branches(
credentials,
input_data.repo_url,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
class GithubListDiscussionsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
) -> list[Output.DiscussionItem]:
api = get_api(credentials)
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
owner, repo = repo_path.split("/")
query = """
query($owner: String!, $repo: String!, $num: Int!) {
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
yield "release", release
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to master)",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{branch}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubCreateRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
def __init__(self):
super().__init__(
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
description="This block lists all users who have starred a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListStargazersBlock.Input,
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
yield "stargazers", stargazers
for stargazer in stargazers:
yield "stargazer", stargazer
class GithubGetRepositoryInfoBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
name: str = SchemaField(description="Repository name")
full_name: str = SchemaField(description="Full repository name (owner/repo)")
description: str = SchemaField(description="Repository description")
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
private: bool = SchemaField(description="Whether the repository is private")
html_url: str = SchemaField(description="Web URL of the repository")
clone_url: str = SchemaField(description="Git clone URL")
stars: int = SchemaField(description="Number of stars")
forks: int = SchemaField(description="Number of forks")
open_issues: int = SchemaField(description="Number of open issues")
error: str = SchemaField(
description="Error message if fetching repo info failed"
)
def __init__(self):
super().__init__(
id="59d4f241-968a-4040-95da-348ac5c5ce27",
description="This block retrieves metadata about a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryInfoBlock.Input,
output_schema=GithubGetRepositoryInfoBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("name", "repo"),
("full_name", "owner/repo"),
("description", "A test repo"),
("default_branch", "main"),
("private", False),
("html_url", "https://github.com/owner/repo"),
("clone_url", "https://github.com/owner/repo.git"),
("stars", 42),
("forks", 5),
("open_issues", 3),
],
test_mock={
"get_repo_info": lambda *args, **kwargs: {
"name": "repo",
"full_name": "owner/repo",
"description": "A test repo",
"default_branch": "main",
"private": False,
"html_url": "https://github.com/owner/repo",
"clone_url": "https://github.com/owner/repo.git",
"stargazers_count": 42,
"forks_count": 5,
"open_issues_count": 3,
}
},
)
@staticmethod
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
api = get_api(credentials)
response = await api.get(repo_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.get_repo_info(credentials, input_data.repo_url)
yield "name", data["name"]
yield "full_name", data["full_name"]
yield "description", data.get("description", "") or ""
yield "default_branch", data["default_branch"]
yield "private", data["private"]
yield "html_url", data["html_url"]
yield "clone_url", data["clone_url"]
yield "stars", data["stargazers_count"]
yield "forks", data["forks_count"]
yield "open_issues", data["open_issues_count"]
except Exception as e:
yield "error", str(e)
class GithubForkRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to fork",
placeholder="https://github.com/owner/repo",
)
organization: str = SchemaField(
description="Organization to fork into (leave empty to fork to your account)",
default="",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the forked repository")
clone_url: str = SchemaField(description="Git clone URL of the fork")
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
error: str = SchemaField(description="Error message if the fork failed")
def __init__(self):
super().__init__(
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
description="This block forks a GitHub repository to your account or an organization.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubForkRepositoryBlock.Input,
output_schema=GithubForkRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"organization": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/myuser/repo"),
("clone_url", "https://github.com/myuser/repo.git"),
("full_name", "myuser/repo"),
],
test_mock={
"fork_repo": lambda *args, **kwargs: (
"https://github.com/myuser/repo",
"https://github.com/myuser/repo.git",
"myuser/repo",
)
},
)
@staticmethod
async def fork_repo(
credentials: GithubCredentials,
repo_url: str,
organization: str,
) -> tuple[str, str, str]:
api = get_api(credentials)
forks_url = repo_url + "/forks"
data: dict[str, str] = {}
if organization:
data["organization"] = organization
response = await api.post(forks_url, json=data)
result = response.json()
return result["html_url"], result["clone_url"], result["full_name"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, clone_url, full_name = await self.fork_repo(
credentials,
input_data.repo_url,
input_data.organization,
)
yield "url", url
yield "clone_url", clone_url
yield "full_name", full_name
except Exception as e:
yield "error", str(e)
class GithubStarRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to star",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the star operation")
error: str = SchemaField(description="Error message if starring failed")
def __init__(self):
super().__init__(
id="bd700764-53e3-44dd-a969-d1854088458f",
description="This block stars a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubStarRepositoryBlock.Input,
output_schema=GithubStarRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Repository starred successfully")],
test_mock={
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
},
)
@staticmethod
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
api = get_api(credentials, convert_urls=False)
repo_path = github_repo_path(repo_url)
owner, repo = repo_path.split("/")
await api.put(
f"https://api.github.com/user/starred/{owner}/{repo}",
headers={"Content-Length": "0"},
)
return "Repository starred successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.star_repo(credentials, input_data.repo_url)
yield "status", status
except Exception as e:
yield "error", str(e)

View File

@@ -1,452 +0,0 @@
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
per_page: int = SchemaField(
description="Number of branches to return per page (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing branches failed")
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(
branches_url, params={"per_page": str(per_page), "page": str(page)}
)
data = response.json()
repo_path = github_repo_path(repo_url)
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
branches = await self.list_branches(
credentials,
input_data.repo_url,
input_data.per_page,
input_data.page,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
except Exception as e:
yield "error", str(e)
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
is_sensitive_action=True,
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubCompareBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
base: str = SchemaField(
description="Base branch or commit SHA",
placeholder="main",
)
head: str = SchemaField(
description="Head branch or commit SHA to compare against base",
placeholder="feature-branch",
)
class Output(BlockSchemaOutput):
class FileChange(TypedDict):
filename: str
status: str
additions: int
deletions: int
patch: str
status: str = SchemaField(
description="Comparison status: ahead, behind, diverged, or identical"
)
ahead_by: int = SchemaField(
description="Number of commits head is ahead of base"
)
behind_by: int = SchemaField(
description="Number of commits head is behind base"
)
total_commits: int = SchemaField(
description="Total number of commits in the comparison"
)
diff: str = SchemaField(description="Unified diff of all file changes")
file: FileChange = SchemaField(
title="Changed File", description="A changed file with its diff"
)
files: list[FileChange] = SchemaField(
description="List of changed files with their diffs"
)
error: str = SchemaField(description="Error message if comparison failed")
def __init__(self):
super().__init__(
id="2e4faa8c-6086-4546-ba77-172d1d560186",
description="This block compares two branches or commits in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCompareBranchesBlock.Input,
output_schema=GithubCompareBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"base": "main",
"head": "feature",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("status", "ahead"),
("ahead_by", 2),
("behind_by", 0),
("total_commits", 2),
("diff", "+++ b/file.py\n+new line"),
(
"files",
[
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
),
(
"file",
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
},
),
],
test_mock={
"compare_branches": lambda *args, **kwargs: {
"status": "ahead",
"ahead_by": 2,
"behind_by": 0,
"total_commits": 2,
"files": [
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
}
},
)
@staticmethod
async def compare_branches(
credentials: GithubCredentials,
repo_url: str,
base: str,
head: str,
) -> dict:
api = get_api(credentials)
safe_base = quote(base, safe="")
safe_head = quote(head, safe="")
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
response = await api.get(compare_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.compare_branches(
credentials,
input_data.repo_url,
input_data.base,
input_data.head,
)
yield "status", data["status"]
yield "ahead_by", data["ahead_by"]
yield "behind_by", data["behind_by"]
yield "total_commits", data["total_commits"]
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
GithubCompareBranchesBlock.Output.FileChange(
filename=f["filename"],
status=f["status"],
additions=f["additions"],
deletions=f["deletions"],
patch=f.get("patch", ""),
)
for f in data.get("files", [])
]
# Build unified diff
diff_parts = []
for f in data.get("files", []):
patch = f.get("patch", "")
if patch:
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
yield "diff", "\n".join(diff_parts)
yield "files", files
for file in files:
yield "file", file
except Exception as e:
yield "error", str(e)

View File

@@ -1,720 +0,0 @@
import base64
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
error: str = SchemaField(description="Error message if reading the file failed")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = (
repo_url
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
)
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
except Exception as e:
yield "error", str(e)
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to main)",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = (
repo_url
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
)
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
except Exception as e:
yield "error", str(e)
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubSearchCodeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
query: str = SchemaField(
description="Search query (GitHub code search syntax)",
placeholder="className language:python",
)
repo: str = SchemaField(
description="Restrict search to a repository (owner/repo format, optional)",
default="",
placeholder="owner/repo",
)
per_page: int = SchemaField(
description="Number of results to return (max 100)",
default=30,
ge=1,
le=100,
)
class Output(BlockSchemaOutput):
class SearchResult(TypedDict):
name: str
path: str
repository: str
url: str
score: float
result: SearchResult = SchemaField(
title="Result", description="A code search result"
)
results: list[SearchResult] = SchemaField(
description="List of code search results"
)
total_count: int = SchemaField(description="Total number of matching results")
error: str = SchemaField(description="Error message if search failed")
def __init__(self):
super().__init__(
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
description="This block searches for code in GitHub repositories.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSearchCodeBlock.Input,
output_schema=GithubSearchCodeBlock.Output,
test_input={
"query": "addClass",
"repo": "owner/repo",
"per_page": 30,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("total_count", 1),
(
"results",
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
),
(
"result",
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
},
),
],
test_mock={
"search_code": lambda *args, **kwargs: (
1,
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
)
},
)
@staticmethod
async def search_code(
credentials: GithubCredentials,
query: str,
repo: str,
per_page: int,
) -> tuple[int, list[Output.SearchResult]]:
api = get_api(credentials, convert_urls=False)
full_query = f"{query} repo:{repo}" if repo else query
params = {"q": full_query, "per_page": str(per_page)}
response = await api.get("https://api.github.com/search/code", params=params)
data = response.json()
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
GithubSearchCodeBlock.Output.SearchResult(
name=item["name"],
path=item["path"],
repository=item["repository"]["full_name"],
url=item["html_url"],
score=item["score"],
)
for item in data["items"]
]
return data["total_count"], results
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
total_count, results = await self.search_code(
credentials,
input_data.query,
input_data.repo,
input_data.per_page,
)
yield "total_count", total_count
yield "results", results
for result in results:
yield "result", result
except Exception as e:
yield "error", str(e)
class GithubGetRepositoryTreeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to get the tree from",
default="main",
)
recursive: bool = SchemaField(
description="Whether to recursively list the entire tree",
default=True,
)
class Output(BlockSchemaOutput):
class TreeEntry(TypedDict):
path: str
type: str
size: int
sha: str
entry: TreeEntry = SchemaField(
title="Tree Entry", description="A file or directory in the tree"
)
entries: list[TreeEntry] = SchemaField(
description="List of all files and directories in the tree"
)
truncated: bool = SchemaField(
description="Whether the tree was truncated due to size"
)
error: str = SchemaField(description="Error message if getting tree failed")
def __init__(self):
super().__init__(
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
description="This block lists the entire file tree of a GitHub repository recursively.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryTreeBlock.Input,
output_schema=GithubGetRepositoryTreeBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"recursive": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("truncated", False),
(
"entries",
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
),
(
"entry",
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
},
),
],
test_mock={
"get_tree": lambda *args, **kwargs: (
False,
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
)
},
)
@staticmethod
async def get_tree(
credentials: GithubCredentials,
repo_url: str,
branch: str,
recursive: bool,
) -> tuple[bool, list[Output.TreeEntry]]:
api = get_api(credentials)
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
params = {"recursive": "1"} if recursive else {}
response = await api.get(tree_url, params=params)
data = response.json()
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
GithubGetRepositoryTreeBlock.Output.TreeEntry(
path=item["path"],
type=item["type"],
size=item.get("size", 0),
sha=item["sha"],
)
for item in data["tree"]
]
return data.get("truncated", False), entries
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
truncated, entries = await self.get_tree(
credentials,
input_data.repo_url,
input_data.branch,
input_data.recursive,
)
yield "truncated", truncated
yield "entries", entries
for entry in entries:
yield "entry", entry
except Exception as e:
yield "error", str(e)

View File

@@ -1,120 +0,0 @@
import inspect
import pytest
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
class TestPreparePrApiUrl:
def test_https_scheme_preserved(self):
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
assert result == "https://github.com/owner/repo/pulls/42/merge"
def test_http_scheme_preserved(self):
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
assert result == "http://github.com/owner/repo/pulls/1/files"
def test_no_scheme_defaults_to_https(self):
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
assert result == "https://github.com/owner/repo/pulls/5/merge"
def test_reviewers_path(self):
result = prepare_pr_api_url(
"https://github.com/owner/repo/pull/99", "requested_reviewers"
)
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
def test_invalid_url_returned_as_is(self):
url = "https://example.com/not-a-pr"
assert prepare_pr_api_url(url, "merge") == url
def test_empty_string(self):
assert prepare_pr_api_url("", "merge") == ""
# ── Error-path block tests ──
# When a block's run() yields ("error", msg), _execute() converts it to a
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
# which returns early on empty test_output).
def _mock_block(block, mocks: dict):
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
for name, mock_fn in mocks.items():
original = getattr(block, name)
if inspect.iscoroutinefunction(original):
async def async_mock(*args, _fn=mock_fn, **kwargs):
return _fn(*args, **kwargs)
setattr(block, name, async_mock)
else:
setattr(block, name, mock_fn)
def _raise(exc: Exception):
"""Helper that returns a callable which raises the given exception."""
def _raiser(*args, **kwargs):
raise exc
return _raiser
@pytest.mark.asyncio
async def test_merge_pr_error_path():
block = GithubMergePullRequestBlock()
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
input_data = {
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_multi_file_commit_error_path():
block = GithubMultiFileCommitBlock()
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
input_data = {
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "test",
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
# ── FileOperation enum tests ──
class TestFileOperation:
def test_upsert_value(self):
assert FileOperation.UPSERT == "upsert"
def test_delete_value(self):
assert FileOperation.DELETE == "delete"
def test_invalid_value_raises(self):
with pytest.raises(ValueError):
FileOperation("create")
def test_invalid_value_raises_typo(self):
with pytest.raises(ValueError):
FileOperation("upser")

View File

@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
h.ignore_links = False
h.ignore_images = True
return h.handle(html_content)
except Exception:
# Keep extraction resilient if html2text is unavailable or fails.
except ImportError:
# Fallback: return raw HTML if html2text is not available
return html_content
# Handle content stored as attachment

View File

@@ -140,31 +140,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
CODESTRAL = "mistralai/codestral-2508"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -172,11 +160,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
MICROSOFT_PHI_4 = "microsoft/phi-4"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_3 = "x-ai/grok-3"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
@@ -354,41 +340,17 @@ MODEL_METADATA = {
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 2.5 Pro",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Pro Preview",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3 Flash Preview",
"OpenRouter",
"Google",
1,
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
@@ -396,15 +358,6 @@ MODEL_METADATA = {
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Flash Lite Preview",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
@@ -426,78 +379,12 @@ MODEL_METADATA = {
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
"open_router",
262144,
None,
"Mistral Large 3 2512",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
"open_router",
131072,
None,
"Mistral Medium 3.1",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
"open_router",
131072,
131072,
"Mistral Small 3.2 24B",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.CODESTRAL: ModelMetadata(
"open_router",
256000,
None,
"Codestral 2508",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Translate 08.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
"open_router",
256000,
32768,
"Command A Reasoning 08.2025",
"OpenRouter",
"Cohere",
3,
),
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Vision 07.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
@@ -510,15 +397,6 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
"open_router",
128000,
8000,
"Sonar Reasoning Pro",
"OpenRouter",
"Perplexity",
2,
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
@@ -564,9 +442,6 @@ MODEL_METADATA = {
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
@@ -576,15 +451,6 @@ MODEL_METADATA = {
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_3: ModelMetadata(
"open_router",
131072,
131072,
"Grok 3",
"OpenRouter",
"xAI",
2,
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr, field_validator
from pydantic import SecretStr
from backend.blocks._base import (
Block,
@@ -13,7 +13,6 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import BlockInput
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -36,20 +35,6 @@ class PerplexityModel(str, Enum):
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
if isinstance(value, PerplexityModel):
return value
try:
return PerplexityModel(value)
except ValueError:
logger.warning(
f"Invalid PerplexityModel '{value}', "
f"falling back to {PerplexityModel.SONAR.value}"
)
return PerplexityModel.SONAR
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
@@ -88,25 +73,6 @@ class PerplexityBlock(Block):
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
@field_validator("model", mode="before")
@classmethod
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
"""Fall back to SONAR if the model value is not a valid
PerplexityModel (e.g. an OpenAI model ID set by the agent
generator)."""
return _sanitize_perplexity_model(v)
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
"""Sanitize the model field before JSON schema validation so that
invalid values are replaced with the default instead of raising a
BlockInputError."""
model_value = data.get("model")
if model_value is not None:
data["model"] = _sanitize_perplexity_model(model_value).value
return super().validate_data(data)
system_prompt: str = SchemaField(
title="System Prompt",
default="",

View File

@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
("post_id", "abc123"),
],
test_mock={"delete_post": lambda creds, post_id: True},
is_sensitive_action=True,
)
@staticmethod
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
("comment_id", "xyz789"),
],
test_mock={"delete_comment": lambda creds, comment_id: True},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
"_convert_to_color": lambda *args, **kwargs: "black",
},
is_sensitive_action=True,
)
async def run(

View File

@@ -1,81 +0,0 @@
"""Unit tests for PerplexityBlock model fallback behavior."""
import pytest
from backend.blocks.perplexity import (
TEST_CREDENTIALS_INPUT,
PerplexityBlock,
PerplexityModel,
)
def _make_input(**overrides) -> dict:
defaults = {
"prompt": "test query",
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return defaults
class TestPerplexityModelFallback:
"""Tests for fallback_invalid_model field_validator."""
def test_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
assert inp.model == PerplexityModel.SONAR
def test_another_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
assert inp.model == PerplexityModel.SONAR
def test_valid_model_string_is_kept(self):
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
assert inp.model == PerplexityModel.SONAR_PRO
def test_valid_enum_value_is_kept(self):
inp = PerplexityBlock.Input(
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
)
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
def test_default_model_when_omitted(self):
inp = PerplexityBlock.Input(**_make_input())
assert inp.model == PerplexityModel.SONAR
@pytest.mark.parametrize(
"model_value",
[
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
)
def test_all_valid_models_accepted(self, model_value: str):
inp = PerplexityBlock.Input(**_make_input(model=model_value))
assert inp.model.value == model_value
class TestPerplexityValidateData:
"""Tests for validate_data which runs during block execution (before
Pydantic instantiation). Invalid models must be sanitized here so
JSON schema validation does not reject them."""
def test_invalid_model_sanitized_before_schema_validation(self):
data = _make_input(model="gpt-5.2-2025-12-11")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == PerplexityModel.SONAR.value
def test_valid_model_unchanged_by_validate_data(self):
data = _make_input(model="perplexity/sonar-pro")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == "perplexity/sonar-pro"
def test_missing_model_uses_default(self):
data = _make_input() # no model key
error = PerplexityBlock.Input.validate_data(data)
assert error is None
inp = PerplexityBlock.Input(**data)
assert inp.model == PerplexityModel.SONAR

View File

@@ -18,13 +18,11 @@ from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.rate_limit import record_token_usage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -38,7 +36,6 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
@@ -49,11 +46,7 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
@@ -228,9 +221,6 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -242,7 +232,6 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
@@ -253,18 +242,7 @@ async def stream_chat_completion_baseline(
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
if not chunk.choices:
continue
delta = chunk.choices[0].delta
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
@@ -433,53 +411,6 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 0
)
turn_completion_tokens = max(
estimate_token_count_str(assistant_text, model=config.model), 0
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Emit token usage and update session for persistence
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = turn_prompt_tokens + turn_completion_tokens
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
)
)
logger.info(
"[Baseline] Turn usage: prompt=%d, completion=%d, total=%d",
turn_prompt_tokens,
turn_completion_tokens,
total_tokens,
)
# Record for rate limiting counters
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
)
except Exception as usage_err:
logger.warning(
"[Baseline] Failed to record token usage: %s", usage_err
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -490,16 +421,4 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -70,20 +70,6 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Each CoPilot turn consumes ~10-15K tokens (system prompt + tool schemas + response),
# so 2.5M daily allows ~170-250 turns/day which is reasonable for normal use.
# TODO: These are global deploy-time constants. For per-user or per-plan limits,
# move to the database (e.g. UserPlan table) and look up in get_usage_status.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
@@ -129,7 +115,7 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=300, # 5 min safety net explicit per-turn pause is the primary mechanism
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",

View File

@@ -73,9 +73,6 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):

View File

@@ -52,11 +52,6 @@ Examples:
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Type coercion**: The platform automatically coerces expanded string values
to match the block's expected input types. For example, if a block expects
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
an @@agptfile: expansion), the string will be parsed into the correct type.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.

View File

@@ -1,253 +0,0 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC).
On Monday itself, ``(7 - weekday) % 7`` is 0; the ``or 7`` fallback
pushes to *next* Monday so the current week's window stays open.
"""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
"""Fetch daily and weekly token counters from Redis.
Returns (daily_used, weekly_used). Returns (0, 0) if Redis is unavailable.
"""
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for usage status, returning zeros", exc_info=True
)
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
now = datetime.now(UTC)
try:
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for rate limit check, allowing request", exc_info=True
)
return
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except Exception:
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
exc_info=True,
)

View File

@@ -1,334 +0,0 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,29 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(
..., description="Total number of tokens (raw, not weighted)"
)
cacheReadTokens: int = Field(
default=0, description="Prompt tokens served from cache (10% cost)"
)
cacheCreationTokens: int = Field(
default=0, description="Prompt tokens written to cache (25% cost)"
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
totalTokens: int = Field(..., description="Total number of tokens")
class StreamError(StreamBaseResponse):

View File

@@ -198,7 +198,6 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""

View File

@@ -1,546 +0,0 @@
"""End-to-end compaction flow test.
Simulates the full service.py compaction lifecycle using real-format
JSONL session files — no SDK subprocess needed. Exercises:
1. TranscriptBuilder loads a "downloaded" transcript
2. User query appended, assistant response streamed
3. PreCompact hook fires → CompactionTracker.on_compact()
4. Next message → emit_start_if_ready() yields spinner events
5. Message after that → emit_end_if_ready() returns end events
6. _read_compacted_entries() reads the CLI session file
7. TranscriptBuilder.replace_entries() syncs state
8. More messages appended post-compaction
9. to_jsonl() exports full state for upload
10. Fresh builder loads the export — roundtrip verified
"""
import asyncio
from pathlib import Path
from backend.copilot.model import ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import CompactionTracker
from backend.copilot.sdk.transcript import strip_progress_entries
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
from backend.util import json
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
def _run(coro):
"""Run an async coroutine synchronously."""
return asyncio.run(coro)
def _read_compacted_entries(path: str) -> tuple[list[dict], str] | None:
"""Test-only: read compacted entries from a session JSONL file.
Returns (parsed_dicts, jsonl_string) from the first ``isCompactSummary``
entry onward, or ``None`` if no summary is found.
"""
content = Path(path).read_text()
lines = content.strip().split("\n")
compact_idx: int | None = None
parsed: list[dict] = []
raw_lines: list[str] = []
for line in lines:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
parsed.append(entry)
raw_lines.append(line.strip())
if compact_idx is None and entry.get("isCompactSummary"):
compact_idx = len(parsed) - 1
if compact_idx is None:
return None
return parsed[compact_idx:], "\n".join(raw_lines[compact_idx:]) + "\n"
# ---------------------------------------------------------------------------
# Fixtures: realistic CLI session file content
# ---------------------------------------------------------------------------
# Pre-compaction conversation
USER_1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "What files are in this project?"},
}
ASST_1_THINKING = {
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
"stop_reason": None,
"stop_sequence": None,
},
}
ASST_1_TOOL = {
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {"command": "ls"},
}
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
TOOL_RESULT_1 = {
"type": "user",
"uuid": "tr1",
"parentUuid": "a1-tool",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu1",
"content": "file1.py\nfile2.py",
}
],
},
}
ASST_1_TEXT = {
"type": "assistant",
"uuid": "a1-text",
"parentUuid": "tr1",
"message": {
"role": "assistant",
"id": "msg_sdk_bbb",
"type": "message",
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Progress entries (should be stripped during upload)
PROGRESS_1 = {
"type": "progress",
"uuid": "prog1",
"parentUuid": "a1-tool",
"data": {"type": "bash_progress", "stdout": "running ls..."},
}
# Second user message
USER_2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1-text",
"message": {"role": "user", "content": "Show me file1.py"},
}
ASST_2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_sdk_ccc",
"type": "message",
"content": [{"type": "text", "text": "Here is file1.py content..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# --- Compaction summary (written by CLI after context compaction) ---
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {
"role": "user",
"content": (
"Summary: User asked about project files. Found file1.py and file2.py. "
"User then asked to see file1.py."
),
},
}
# Post-compaction assistant response
POST_COMPACT_ASST = {
"type": "assistant",
"uuid": "a3",
"parentUuid": "cs1",
"message": {
"role": "assistant",
"id": "msg_sdk_ddd",
"type": "message",
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Post-compaction user follow-up
USER_3 = {
"type": "user",
"uuid": "u3",
"parentUuid": "a3",
"message": {"role": "user", "content": "Now show file2.py"},
}
ASST_3 = {
"type": "assistant",
"uuid": "a4",
"parentUuid": "u3",
"message": {
"role": "assistant",
"id": "msg_sdk_eee",
"type": "message",
"content": [{"type": "text", "text": "Here is file2.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# ---------------------------------------------------------------------------
# E2E test
# ---------------------------------------------------------------------------
class TestCompactionE2E:
def _write_session_file(self, session_dir, entries):
"""Write a CLI session JSONL file."""
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(*entries))
return path
def test_full_compaction_lifecycle(self, tmp_path):
"""Simulate the complete service.py compaction flow.
Timeline:
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
2. Current turn: download → load_previous
3. User sends "Now show file2.py" → append_user
4. SDK starts streaming response
5. Mid-stream: PreCompact hook fires (context too large)
6. CLI writes compaction summary to session file
7. Next SDK message → emit_start (spinner)
8. Following message → emit_end (end events)
9. _read_compacted_entries reads the session file
10. replace_entries syncs TranscriptBuilder
11. More assistant messages appended
12. Export → upload → next turn downloads it
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
previous_transcript = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
builder = TranscriptBuilder()
builder.load_previous(previous_transcript)
assert builder.entry_count == 7
# --- Step 3: User sends new query ---
builder.append_user("Now show file2.py")
assert builder.entry_count == 8
# --- Step 4: SDK starts streaming ---
builder.append_assistant(
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
model="claude-sonnet-4-20250514",
)
assert builder.entry_count == 9
# --- Step 5-6: PreCompact fires, CLI writes session file ---
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
USER_3,
ASST_3,
],
)
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
# on_compact is a property returning Event.set callable
tracker.on_compact()
# --- Step 8: Next SDK message arrives → emit_start ---
start_events = tracker.emit_start_if_ready()
assert len(start_events) == 3
assert isinstance(start_events[0], StreamStartStep)
assert isinstance(start_events[1], StreamToolInputStart)
assert isinstance(start_events[2], StreamToolInputAvailable)
# Verify tool_call_id is set
tool_call_id = start_events[1].toolCallId
assert tool_call_id.startswith("compaction-")
# --- Step 9: Following message → emit_end ---
end_events = _run(tracker.emit_end_if_ready(session))
assert len(end_events) == 2
assert isinstance(end_events[0], StreamToolOutputAvailable)
assert isinstance(end_events[1], StreamFinishStep)
# Verify same tool_call_id
assert end_events[0].toolCallId == tool_call_id
# Session should have compaction messages persisted
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[1].role == "tool"
# --- Step 10: _read_compacted_entries + replace_entries ---
result = _read_compacted_entries(str(session_file))
assert result is not None
compacted_dicts, compacted_jsonl = result
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
assert len(compacted_dicts) == 4
assert compacted_dicts[0]["uuid"] == "cs1"
assert compacted_dicts[0]["isCompactSummary"] is True
# Replace builder state with compacted JSONL
old_count = builder.entry_count
builder.replace_entries(compacted_jsonl)
assert builder.entry_count == 4 # Only compacted entries
assert builder.entry_count < old_count # Compaction reduced entries
# --- Step 11: More assistant messages after compaction ---
builder.append_assistant(
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
model="claude-sonnet-4-20250514",
stop_reason="end_turn",
)
assert builder.entry_count == 5
# --- Step 12: Export for upload ---
output = builder.to_jsonl()
assert output # Not empty
output_entries = [json.loads(line) for line in output.strip().split("\n")]
assert len(output_entries) == 5
# Verify structure:
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
assert output_entries[0]["type"] == "summary"
assert output_entries[0].get("isCompactSummary") is True
assert output_entries[0]["uuid"] == "cs1"
assert output_entries[1]["uuid"] == "a3"
assert output_entries[2]["uuid"] == "u3"
assert output_entries[3]["uuid"] == "a4"
assert output_entries[4]["type"] == "assistant"
# Verify parent chain is intact
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
assert output_entries[4]["parentUuid"] == "a4" # new → a4
# --- Step 13: Roundtrip — next turn loads this export ---
builder2 = TranscriptBuilder()
builder2.load_previous(output)
assert builder2.entry_count == 5
# isCompactSummary survives roundtrip
output2 = builder2.to_jsonl()
first_entry = json.loads(output2.strip().split("\n")[0])
assert first_entry.get("isCompactSummary") is True
# Can append more messages
builder2.append_user("What about file3.py?")
assert builder2.entry_count == 6
final_output = builder2.to_jsonl()
last_entry = json.loads(final_output.strip().split("\n")[-1])
assert last_entry["type"] == "user"
# Parented to the last entry from previous turn
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
def test_double_compaction_within_session(self, tmp_path):
"""Two compactions in the same session (across reset_for_query)."""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---
builder.append_user("first question")
builder.append_assistant([{"type": "text", "text": "first answer"}])
# Write session file for first compaction
first_summary = {
"type": "summary",
"uuid": "cs-first",
"isCompactSummary": True,
"message": {"role": "user", "content": "First compaction summary"},
}
first_post = {
"type": "assistant",
"uuid": "a-first",
"parentUuid": "cs-first",
"message": {"role": "assistant", "content": "first post-compact"},
}
file1 = session_dir / "session1.jsonl"
file1.write_text(_make_jsonl(first_summary, first_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events1 = _run(tracker.emit_end_if_ready(session))
assert len(end_events1) == 2 # output + finish
result1_entries = _read_compacted_entries(str(file1))
assert result1_entries is not None
_, compacted1_jsonl = result1_entries
builder.replace_entries(compacted1_jsonl)
assert builder.entry_count == 2
# --- Reset for second query ---
tracker.reset_for_query()
# --- Second query with compaction ---
builder.append_user("second question")
builder.append_assistant([{"type": "text", "text": "second answer"}])
second_summary = {
"type": "summary",
"uuid": "cs-second",
"isCompactSummary": True,
"message": {"role": "user", "content": "Second compaction summary"},
}
second_post = {
"type": "assistant",
"uuid": "a-second",
"parentUuid": "cs-second",
"message": {"role": "assistant", "content": "second post-compact"},
}
file2 = session_dir / "session2.jsonl"
file2.write_text(_make_jsonl(second_summary, second_post))
tracker.on_compact()
tracker.emit_start_if_ready()
end_events2 = _run(tracker.emit_end_if_ready(session))
assert len(end_events2) == 2 # output + finish
result2_entries = _read_compacted_entries(str(file2))
assert result2_entries is not None
_, compacted2_jsonl = result2_entries
builder.replace_entries(compacted2_jsonl)
assert builder.entry_count == 2 # Only second compaction entries
# Export and verify
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["uuid"] == "cs-second"
assert entries[0].get("isCompactSummary") is True
def test_strip_progress_then_load_then_compact_roundtrip(self, tmp_path):
"""Full pipeline: strip → load → compact → replace → export → reload.
This tests the exact sequence that happens across two turns:
Turn 1: SDK produces transcript with progress entries
Upload: strip_progress_entries removes progress, upload to cloud
Turn 2: Download → load_previous → compaction fires → replace → export
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
"""
session_dir = tmp_path / "session"
session_dir.mkdir(parents=True)
# --- Turn 1: SDK produces raw transcript ---
raw_content = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
# Strip progress for upload
stripped = strip_progress_entries(raw_content)
stripped_entries = [
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
]
# Progress should be gone
assert not any(e.get("type") == "progress" for e in stripped_entries)
assert len(stripped_entries) == 7 # 8 - 1 progress
# --- Turn 2: Download stripped, load, compaction happens ---
builder = TranscriptBuilder()
builder.load_previous(stripped)
assert builder.entry_count == 7
builder.append_user("Now show file2.py")
builder.append_assistant(
[{"type": "text", "text": "Reading file2.py..."}],
model="claude-sonnet-4-20250514",
)
# CLI writes session file with compaction
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
],
)
result = _read_compacted_entries(str(session_file))
assert result is not None
_, compacted_jsonl = result
builder.replace_entries(compacted_jsonl)
# Append post-compaction message
builder.append_user("Thanks!")
output = builder.to_jsonl()
# --- Turn 3: Fresh load of Turn 2 export ---
builder3 = TranscriptBuilder()
builder3.load_previous(output)
# Should have: compact_summary + post_compact_asst + "Thanks!"
assert builder3.entry_count == 3
# Compact summary survived the full pipeline
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
assert first.get("isCompactSummary") is True
assert first["type"] == "summary"

View File

@@ -26,17 +26,3 @@ For other services, search the MCP registry at https://registry.modelcontextprot
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
### Communication style
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
Use plain, friendly language instead:
| Instead of… | Say… |
|---|---|
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".

View File

@@ -221,12 +221,12 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.warning(
"Unexpected ResultMessage subtype: %s", sdk_message.subtype
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
)
responses.append(StreamFinish())
else:
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses

View File

@@ -52,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -71,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning("Blocked tool access attempt: %s", tool_name)
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -111,9 +111,7 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -171,7 +169,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info("[SDK] Blocked background Task, user=%s", user_id)
logger.info(f"[SDK] Blocked background Task, user={user_id}")
return cast(
SyncHookJSONOutput,
_deny(
@@ -213,7 +211,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:

View File

@@ -40,13 +40,11 @@ from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..rate_limit import record_token_usage
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -56,7 +54,6 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
@@ -78,12 +75,8 @@ from .tool_adapter import (
wait_for_stash,
)
from .transcript import (
COMPACT_THRESHOLD_BYTES,
TranscriptDownload,
cleanup_cli_project_dir,
compact_transcript,
download_transcript,
read_cli_session_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -301,7 +294,7 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
"""
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
# Clean the CLI's project directory (transcripts + tool-results).
@@ -395,7 +388,7 @@ async def _compress_messages(
client=client,
)
except Exception as e:
logger.warning("[SDK] Context compression with LLM failed: %s", e)
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
@@ -631,56 +624,6 @@ async def _prepare_file_attachments(
return PreparedAttachments(hint=hint, image_blocks=image_blocks)
async def _maybe_compact_and_upload(
dl: TranscriptDownload,
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> str:
"""Compact an oversized transcript and upload the compacted version.
Returns the (possibly compacted) transcript content, or an empty string
if compaction was needed but failed.
"""
content = dl.content
if len(content) <= COMPACT_THRESHOLD_BYTES:
return content
logger.warning(
"%s Transcript oversized (%dB > %dB), compacting",
log_prefix,
len(content),
COMPACT_THRESHOLD_BYTES,
)
compacted = await compact_transcript(content, log_prefix=log_prefix)
if not compacted:
logger.warning(
"%s Compaction failed, skipping resume for this turn", log_prefix
)
return ""
# Keep the original message_count: it reflects the number of
# session.messages covered by this transcript, which the gap-fill
# logic uses as a slice index. Counting JSONL lines would give a
# smaller number (compacted messages != session message count) and
# cause already-covered messages to be re-injected.
try:
await upload_transcript(
user_id=user_id,
session_id=session_id,
content=compacted,
message_count=dl.message_count,
log_prefix=log_prefix,
)
except Exception:
logger.warning(
"%s Failed to upload compacted transcript",
log_prefix,
exc_info=True,
)
return compacted
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -792,14 +735,6 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
total_tokens = 0 # computed once before StreamUsage, reused in finally
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -892,33 +827,20 @@ async def stream_chat_completion_sdk(
is_valid,
)
if is_valid:
transcript_content = await _maybe_compact_and_upload(
dl,
user_id=user_id or "",
session_id=session_id,
log_prefix=log_prefix,
)
# Load previous context into builder (empty string is a no-op)
if transcript_content:
transcript_builder.load_previous(
transcript_content, log_prefix=log_prefix
)
resume_file = (
write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if transcript_content
else None
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"{log_prefix} Using --resume ({len(transcript_content)}B, "
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning("%s Transcript downloaded but invalid", log_prefix)
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"{log_prefix} No transcript available "
@@ -1188,7 +1110,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details and capture token usage
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1207,46 +1129,9 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with
# the CLI's compacted session file so the uploaded
# transcript reflects compaction.
compaction_events = await compaction.emit_end_if_ready(session)
for ev in compaction_events:
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
yield ev
if compaction_events and sdk_cwd:
cli_content = await read_cli_session_file(sdk_cwd)
if cli_content:
transcript_builder.replace_entries(
cli_content, log_prefix=log_prefix
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
@@ -1440,27 +1325,6 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
# Compute total_tokens once; reused in the finally block for
# session persistence and rate-limit recording.
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if total_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
totalTokens=total_tokens,
cacheReadTokens=turn_cache_read_tokens,
cacheCreationTokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1525,48 +1389,6 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
# total_tokens is computed once before StreamUsage yield above.
if total_tokens > 0:
if session is not None:
session.usage.append(
Usage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
)
logger.info(
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d, "
"output=%d, total=%d, cost_usd=%s",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
total_tokens,
turn_cost_usd,
)
if user_id and total_tokens > 0:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(
"%s Failed to record token usage: %s",
log_prefix,
usage_err,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
@@ -1662,6 +1484,6 @@ async def _update_title_async(
)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug("[SDK] Generated title for %s: %s", session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning("[SDK] Failed to update session title: %s", e)
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler

View File

@@ -13,17 +13,10 @@ filesystem for self-hosted) — no DB column needed.
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import CompressResult, compress_context
logger = logging.getLogger(__name__)
@@ -41,11 +34,6 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
@dataclass
class TranscriptDownload:
@@ -94,11 +82,7 @@ def strip_progress_entries(content: str) -> str:
parent = entry.get("parentUuid", "")
if uid:
uuid_to_parent[uid] = parent
if (
entry.get("type", "") in STRIPPABLE_TYPES
and uid
and not entry.get("isCompactSummary")
):
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
stripped_uuids.add(uid)
# Second pass: keep non-stripped entries, reparenting where needed.
@@ -109,9 +93,7 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
while parent in stripped_uuids:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
@@ -124,9 +106,7 @@ def strip_progress_entries(content: str) -> str:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
if entry.get("type", "") in STRIPPABLE_TYPES:
continue
uid = entry.get("uuid", "")
if uid in reparented:
@@ -157,78 +137,32 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
Returns ``None`` if the path would escape the projects base.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
import shutil
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning("[Transcript] Project dir escaped base: %s", project_dir)
return None
return project_dir
async def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any mid-stream compaction.
After the CLI compacts context, its session file contains the compacted
conversation. Reading this file lets ``TranscriptBuilder`` replace its
uncompacted entries with the CLI's compacted version.
"""
import aiofiles
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = list(Path(project_dir).glob("*.jsonl"))
if not jsonl_files:
logger.debug("[Transcript] No CLI session file in %s", project_dir)
return None
# Pick the most recently modified file (there should only be one per turn).
# Guard against races where a file is deleted between glob and stat.
candidates: list[tuple[float, Path]] = []
for p in jsonl_files:
try:
candidates.append((p.stat().st_mtime, p))
except OSError:
continue
if not candidates:
logger.debug("[Transcript] No readable CLI session file in %s", project_dir)
return None
# Resolve + prefix check to prevent symlink escapes.
session_file = max(candidates, key=lambda item: item[0])[1]
real_path = str(session_file.resolve())
if not real_path.startswith(project_dir + os.sep):
logger.warning("[Transcript] Session file escaped project dir: %s", real_path)
return None
try:
async with aiofiles.open(real_path) as f:
content = await f.read()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
real_path,
len(content),
logger.warning(
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory."""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
else:
logger.debug("[Transcript] Project dir not found: %s", project_dir)
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
def write_transcript_to_tempfile(
@@ -246,7 +180,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
return None
try:
@@ -256,17 +190,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
return jsonl_path
except OSError as e:
logger.warning("[Transcript] Failed to write resume file: %s", e)
logger.warning(f"[Transcript] Failed to write resume file: {e}")
return None
@@ -410,14 +344,11 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
)
@@ -440,10 +371,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug("%s No transcript in storage", log_prefix)
logger.debug(f"{log_prefix} No transcript in storage")
return None
except Exception as e:
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -463,14 +394,10 @@ async def download_transcript(
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except FileNotFoundError:
except (FileNotFoundError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -478,171 +405,15 @@ async def download_transcript(
)
# ---------------------------------------------------------------------------
# Transcript compaction
# ---------------------------------------------------------------------------
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
# Transcripts above this byte threshold are compacted at download time.
COMPACT_THRESHOLD_BYTES = 400_000
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string."""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
"""
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content", "")
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
text = sub.get("text")
str_parts.append(
str(text) if text is not None else json.dumps(sub)
)
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to message dicts for compress_context."""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format."""
lines: list[str] = []
last_uuid: str | None = None
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
async def _run_compression(
messages: list[dict],
model: str,
cfg: ChatConfig,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback."""
try:
async with openai.AsyncOpenAI(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
) as client:
return await compress_context(messages=messages, model=model, client=client)
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await compress_context(messages=messages, model=model, client=None)
async def compact_transcript(
content: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Returns the compacted JSONL string, or ``None`` on failure.
"""
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, cfg.model, cfg, log_prefix)
if not result.was_compacted:
logger.info("%s Transcript already within token budget", log_prefix)
return content
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None
logger.warning(f"[Transcript] Failed to delete transcript: {e}")

View File

@@ -31,7 +31,6 @@ class TranscriptEntry(BaseModel):
uuid: str
parentUuid: str | None
message: dict[str, Any]
isCompactSummary: bool | None = None
class TranscriptBuilder:
@@ -79,12 +78,10 @@ class TranscriptBuilder:
)
continue
# Skip STRIPPABLE_TYPES unless the entry is a compaction summary.
# Compaction summaries may have type "summary" but must be preserved
# so --resume can reconstruct the compacted conversation.
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
entry_type = data.get("type", "")
is_compact = data.get("isCompactSummary", False)
if entry_type in STRIPPABLE_TYPES and not is_compact:
if entry_type in STRIPPABLE_TYPES:
continue
entry = TranscriptEntry(
@@ -92,7 +89,6 @@ class TranscriptBuilder:
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
isCompactSummary=True if is_compact else None,
)
self._entries.append(entry)
self._last_uuid = entry.uuid
@@ -181,33 +177,6 @@ class TranscriptBuilder:
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
def replace_entries(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Replace all entries with compacted JSONL content.
Called after the CLI performs mid-stream compaction so the builder's
state reflects the compacted conversation instead of the full
pre-compaction history.
"""
prev_count = len(self._entries)
temp = TranscriptBuilder()
try:
temp.load_previous(content, log_prefix=log_prefix)
except Exception:
logger.exception(
"%s Failed to parse compacted transcript; keeping %d existing entries",
log_prefix,
prev_count,
)
return
self._entries = temp._entries
self._last_uuid = temp._last_uuid
logger.info(
"%s Replaced %d entries with %d compacted entries",
log_prefix,
prev_count,
len(self._entries),
)
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""

View File

@@ -2,25 +2,14 @@
import os
import pytest
from backend.util import json
from .transcript import (
COMPACT_MSG_ID_PREFIX,
STRIPPABLE_TYPES,
_cli_project_dir,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
read_cli_session_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
def _make_jsonl(*entries: dict) -> str:
@@ -46,14 +35,6 @@ PROGRESS_ENTRY = {
"data": {"type": "bash_progress", "stdout": "running..."},
}
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"parentUuid": None,
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous conversation..."},
}
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
@@ -256,121 +237,6 @@ class TestStripProgressEntries:
# Should return just a newline (empty content stripped)
assert result.strip() == ""
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_path_for_valid_cwd(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
result = _cli_project_dir("/tmp/copilot-abc")
assert result is not None
assert "projects" in result
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
projects = tmp_path / "projects"
projects.mkdir()
# A cwd that encodes to something with .. shouldn't escape
result = _cli_project_dir("/tmp/copilot-test")
# Should return a valid path (no traversal possible with alphanum encoding)
assert result is None or result.startswith(str(projects))
# --- read_cli_session_file ---
class TestReadCliSessionFile:
@pytest.mark.asyncio
async def test_reads_session_file(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
# Create the CLI project directory structure
cwd = "/tmp/copilot-testread"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# Write a session file
session_file = project_dir / "test-session.jsonl"
session_file.write_text(json.dumps(ASST_MSG) + "\n")
result = await read_cli_session_file(cwd)
assert result is not None
assert "assistant" in result
@pytest.mark.asyncio
async def test_returns_none_when_no_files(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
cwd = "/tmp/copilot-nofiles"
import re
encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
project_dir = tmp_path / "projects" / encoded
project_dir.mkdir(parents=True)
# No jsonl files
result = await read_cli_session_file(cwd)
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_dir_missing(self, tmp_path, monkeypatch):
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(tmp_path))
(tmp_path / "projects").mkdir()
result = await read_cli_session_file("/tmp/copilot-nonexistent")
assert result is None
# --- _transcript_to_messages / _messages_to_transcript ---
class TestTranscriptMessageConversion:
def test_roundtrip_preserves_roles(self):
transcript = _make_jsonl(USER_MSG, ASST_MSG)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
def test_messages_to_transcript_produces_valid_jsonl(self):
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result) is True
def test_strips_strippable_types(self):
transcript = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {"role": "user"}},
USER_MSG,
ASST_MSG,
)
messages = _transcript_to_messages(transcript)
assert len(messages) == 2 # progress entry skipped
def test_flattens_assistant_content_blocks(self):
asst_with_blocks = {
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "hello"},
{"type": "tool_use", "name": "bash"},
],
},
}
messages = _transcript_to_messages(_make_jsonl(asst_with_blocks))
assert len(messages) == 1
assert "hello" in messages[0]["content"]
assert "[tool_use: bash]" in messages[0]["content"]
def test_empty_messages_returns_empty(self):
result = _messages_to_transcript([])
assert result == ""
def test_no_strippable_entries(self):
"""When there's nothing to strip, output matches input structure."""
content = _make_jsonl(USER_MSG, ASST_MSG)
@@ -416,654 +282,3 @@ class TestTranscriptMessageConversion:
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented
# --- TranscriptBuilder ---
class TestTranscriptBuilderReplaceEntries:
"""Tests for TranscriptBuilder.replace_entries — the compaction sync path."""
def test_replace_entries_with_valid_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
builder.append_assistant([{"type": "text", "text": "world"}])
assert builder.entry_count == 2
# Replace with compacted content (one user + one assistant)
compacted = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(compacted)
assert builder.entry_count == 2
def test_replace_entries_keeps_old_on_corrupt_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
# Corrupt content that fails to parse
builder.replace_entries("not valid json at all\n")
# Should still have old entries (load_previous skips invalid lines,
# but if ALL lines are invalid, temp builder is empty → exception path)
assert builder.entry_count >= 0 # doesn't crash
def test_replace_entries_with_empty_content(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
builder.replace_entries("")
# Empty content → load_previous returns early → temp is empty
# replace_entries swaps to empty (0 entries)
assert builder.entry_count == 0
def test_replace_entries_filters_strippable_types(self):
"""Strippable types (progress, file-history-snapshot) are filtered out."""
builder = TranscriptBuilder()
builder.append_user("hello")
content = _make_jsonl(
{"type": "progress", "uuid": "p1", "message": {}},
USER_MSG,
ASST_MSG,
)
builder.replace_entries(content)
# Only user + assistant should remain (progress filtered)
assert builder.entry_count == 2
def test_replace_entries_preserves_uuids(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.replace_entries(content)
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
first = json.loads(lines[0])
assert first["uuid"] == "u1"
class TestTranscriptBuilderBasic:
def test_append_user_and_assistant(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "hello"}])
assert builder.entry_count == 2
assert not builder.is_empty
def test_to_jsonl_empty(self):
builder = TranscriptBuilder()
assert builder.to_jsonl() == ""
assert builder.is_empty
def test_load_previous_and_append(self):
builder = TranscriptBuilder()
content = _make_jsonl(USER_MSG, ASST_MSG)
builder.load_previous(content)
assert builder.entry_count == 2
builder.append_user("new message")
assert builder.entry_count == 3
def test_consecutive_assistant_entries_share_message_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "part1"}])
builder.append_assistant([{"type": "text", "text": "part2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[2])
assert asst1["message"]["id"] == asst2["message"]["id"]
def test_non_consecutive_assistant_entries_get_new_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant([{"type": "text", "text": "response1"}])
builder.append_user("followup")
builder.append_assistant([{"type": "text", "text": "response2"}])
jsonl = builder.to_jsonl()
lines = jsonl.strip().split("\n")
asst1 = json.loads(lines[1])
asst2 = json.loads(lines[3])
assert asst1["message"]["id"] != asst2["message"]["id"]
class TestCompactSummaryRoundtrip:
"""Verify isCompactSummary survives export→reload roundtrip."""
def test_load_previous_preserves_compact_summary(self):
"""Compaction summary with type 'summary' should not be stripped."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
# summary type is in STRIPPABLE_TYPES, but isCompactSummary keeps it
assert builder.entry_count == 3
def test_export_reload_preserves_compact_summary(self):
"""Critical: isCompactSummary must survive to_jsonl → load_previous."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder1 = TranscriptBuilder()
builder1.load_previous(content)
assert builder1.entry_count == 3
exported = builder1.to_jsonl()
# Verify isCompactSummary is in the exported JSONL
first_line = json.loads(exported.strip().split("\n")[0])
assert first_line.get("isCompactSummary") is True
# Reload and verify it's still preserved
builder2 = TranscriptBuilder()
builder2.load_previous(exported)
assert builder2.entry_count == 3
def test_strip_progress_preserves_compact_summary(self):
"""strip_progress_entries should keep isCompactSummary entries."""
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" in types # Not stripped despite being in STRIPPABLE_TYPES
compact = [e for e in entries if e.get("isCompactSummary")]
assert len(compact) == 1
def test_regular_summary_still_stripped(self):
"""Non-compact summaries should still be stripped."""
regular_summary = {
"type": "summary",
"uuid": "rs1",
"summary": "Session summary",
}
content = _make_jsonl(regular_summary, USER_MSG, ASST_MSG)
stripped = strip_progress_entries(content)
entries = [json.loads(line) for line in stripped.strip().split("\n")]
types = [e.get("type") for e in entries]
assert "summary" not in types
def test_replace_entries_preserves_compact_summary(self):
"""replace_entries should preserve isCompactSummary entries."""
builder = TranscriptBuilder()
builder.append_user("old")
content = _make_jsonl(COMPACT_SUMMARY, USER_MSG, ASST_MSG)
builder.replace_entries(content)
assert builder.entry_count == 3
# Verify by re-exporting
exported = builder.to_jsonl()
first = json.loads(exported.strip().split("\n")[0])
assert first.get("isCompactSummary") is True
# --- _flatten_assistant_content ---
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
{"type": "text", "text": "Let me read that."},
{"type": "tool_use", "name": "read", "id": "t1", "input": {}},
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: read]" in result
def test_string_blocks(self):
"""Plain strings in the list should be included."""
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_empty_list(self):
assert _flatten_assistant_content([]) == ""
def test_tool_use_missing_name(self):
blocks = [{"type": "tool_use", "id": "t1", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: ?]"
# --- _flatten_tool_result_content ---
class TestFlattenToolResultContent:
def test_tool_result_with_text(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "text", "text": "file contents here"}],
}
]
assert _flatten_tool_result_content(blocks) == "file contents here"
def test_tool_result_with_string_content(self):
blocks = [
{"type": "tool_result", "tool_use_id": "t1", "content": "simple result"}
]
assert _flatten_tool_result_content(blocks) == "simple result"
def test_tool_result_with_nested_list(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [
{"type": "text", "text": "line 1"},
{"type": "text", "text": "line 2"},
],
}
]
assert _flatten_tool_result_content(blocks) == "line 1\nline 2"
def test_text_blocks(self):
blocks = [{"type": "text", "text": "some text"}]
assert _flatten_tool_result_content(blocks) == "some text"
def test_string_items(self):
assert _flatten_tool_result_content(["raw string"]) == "raw string"
def test_empty_list(self):
assert _flatten_tool_result_content([]) == ""
def test_tool_result_none_text_uses_json(self):
"""Dicts without text key fall back to json.dumps."""
blocks = [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": [{"type": "image", "source": "data:..."}],
}
]
result = _flatten_tool_result_content(blocks)
assert "image" in result # json.dumps fallback includes the key
# --- _transcript_to_messages ---
class TestTranscriptToMessages:
def test_basic_conversation(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hello"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hi there"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0] == {"role": "user", "content": "hello"}
assert msgs[1] == {"role": "assistant", "content": "hi there"}
def test_strips_progress_entries(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "progress",
"uuid": "p1",
"message": {"role": "user", "content": "..."},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "ok"}],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["role"] == "user"
assert msgs[1]["role"] == "assistant"
def test_preserves_compact_summaries(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "Summary of previous..."},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of previous..."
def test_strips_regular_summary(self):
content = _make_jsonl(
{
"type": "summary",
"uuid": "s1",
"message": {"role": "user", "content": "Session summary"},
},
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert msgs[0]["content"] == "hi"
def test_skips_entries_without_role(self):
content = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {}},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "hi"},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
def test_tool_result_content(self):
content = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "file contents",
}
],
},
},
)
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
assert "file contents" in msgs[0]["content"]
def test_empty_content(self):
assert _transcript_to_messages("") == []
assert _transcript_to_messages(" \n ") == []
def test_invalid_json_lines_skipped(self):
content = '{"type":"user","uuid":"u1","message":{"role":"user","content":"hi"}}\nnot json\n'
msgs = _transcript_to_messages(content)
assert len(msgs) == 1
# --- _messages_to_transcript ---
class TestMessagesToTranscript:
def test_basic_roundtrip_structure(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
result = _messages_to_transcript(messages)
assert result.endswith("\n")
lines = [json.loads(line) for line in result.strip().split("\n")]
assert len(lines) == 2
# User entry
assert lines[0]["type"] == "user"
assert lines[0]["message"]["role"] == "user"
assert lines[0]["message"]["content"] == "hello"
assert lines[0]["parentUuid"] is None
# Assistant entry
assert lines[1]["type"] == "assistant"
assert lines[1]["message"]["role"] == "assistant"
assert lines[1]["message"]["content"] == [{"type": "text", "text": "hi there"}]
assert lines[1]["message"]["id"].startswith(COMPACT_MSG_ID_PREFIX)
assert lines[1]["parentUuid"] == lines[0]["uuid"]
def test_parent_uuid_chain(self):
messages = [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
]
result = _messages_to_transcript(messages)
lines = [json.loads(line) for line in result.strip().split("\n")]
assert lines[0]["parentUuid"] is None
assert lines[1]["parentUuid"] == lines[0]["uuid"]
assert lines[2]["parentUuid"] == lines[1]["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_assistant_empty_content(self):
messages = [{"role": "assistant", "content": ""}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["message"]["content"] == []
def test_output_is_valid_transcript(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "world"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result)
# --- _transcript_to_messages + _messages_to_transcript roundtrip ---
class TestTranscriptCompactionRoundtrip:
def test_content_preserved_through_roundtrip(self):
"""Messages→transcript→messages preserves content."""
original = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "Thanks"},
]
transcript = _messages_to_transcript(original)
recovered = _transcript_to_messages(transcript)
assert len(recovered) == len(original)
for orig, rec in zip(original, recovered):
assert orig["role"] == rec["role"]
assert orig["content"] == rec["content"]
def test_full_transcript_to_messages_and_back(self):
"""Real-ish JSONL → messages → transcript → messages roundtrip."""
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "explain python"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"content": [
{"type": "text", "text": "Python is a programming language."}
],
},
},
{
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "t1",
"content": "output of ls",
}
],
},
},
)
msgs1 = _transcript_to_messages(source)
assert len(msgs1) == 3
rebuilt = _messages_to_transcript(msgs1)
msgs2 = _transcript_to_messages(rebuilt)
assert len(msgs2) == len(msgs1)
for m1, m2 in zip(msgs1, msgs2):
assert m1["role"] == m2["role"]
# Content may differ in format (list vs string) but text is preserved
assert m1["content"] in m2["content"] or m2["content"] in m1["content"]
# --- compact_transcript ---
class TestCompactTranscript:
@pytest.mark.asyncio
async def test_too_few_messages_returns_none(self):
"""Transcripts with < 2 messages can't be compacted."""
single = _make_jsonl(
{"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
)
result = await compact_transcript(single)
assert result is None
@pytest.mark.asyncio
async def test_empty_transcript_returns_none(self):
result = await compact_transcript("")
assert result is None
@pytest.mark.asyncio
async def test_compaction_produces_valid_transcript(self, monkeypatch):
"""When compress_context compacts, result should be valid JSONL."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[
{"role": "user", "content": "Summary of conversation"},
{"role": "assistant", "content": "Acknowledged"},
],
token_count=50,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=5,
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "msg1"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "reply1"}],
},
},
{
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "msg2"},
},
)
result = await compact_transcript(source)
assert result is not None
assert validate_transcript(result)
# Verify compacted content
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
assert msgs[0]["content"] == "Summary of conversation"
@pytest.mark.asyncio
async def test_no_compaction_needed_returns_original(self, monkeypatch):
"""When compress_context says no compaction needed, return original."""
from unittest.mock import AsyncMock
from backend.util.prompt import CompressResult
mock_result = CompressResult(
messages=[], token_count=100, was_compacted=False, original_token_count=100
)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(return_value=mock_result),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result == source # Unchanged
@pytest.mark.asyncio
async def test_compression_failure_returns_none(self, monkeypatch):
"""When _run_compression raises, compact_transcript returns None."""
from unittest.mock import AsyncMock
monkeypatch.setattr(
"backend.copilot.sdk.transcript._run_compression",
AsyncMock(side_effect=RuntimeError("LLM unavailable")),
)
source = _make_jsonl(
{
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
},
{
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
},
},
)
result = await compact_transcript(source)
assert result is None

View File

@@ -23,11 +23,6 @@ from typing import Any, Literal
import orjson
from backend.api.model import CopilotCompletionPayload
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
@@ -43,7 +38,6 @@ from .response_model import (
logger = logging.getLogger(__name__)
config = ChatConfig()
_notification_bus = AsyncRedisNotificationEventBus()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_sessions: dict[str, asyncio.Task] = {}
@@ -751,29 +745,6 @@ async def mark_session_completed(
# Clean up local session reference if exists
_local_sessions.pop(session_id, None)
# Publish copilot completion notification via WebSocket
if meta:
parsed = _parse_session_meta(meta, session_id)
if parsed.user_id:
try:
await _notification_bus.publish(
NotificationEvent(
user_id=parsed.user_id,
payload=CopilotCompletionPayload(
type="copilot_completion",
event="session_completed",
session_id=session_id,
status=status,
),
)
)
except Exception as e:
logger.warning(
f"Failed to publish copilot completion notification "
f"for session {session_id}: {e}"
)
return True

View File

@@ -829,12 +829,8 @@ class AgentFixer:
For nodes whose block has category "AI", this function ensures that the
input_default has a "model" parameter set to one of the allowed models.
If missing or set to an unsupported value, it is replaced with the
appropriate default.
Blocks that define their own ``enum`` constraint on the ``model`` field
in their inputSchema (e.g. PerplexityBlock) are validated against that
enum instead of the generic allowed set.
If missing or set to an unsupported value, it is replaced with
default_model.
Args:
agent: The agent dictionary to fix
@@ -844,7 +840,7 @@ class AgentFixer:
Returns:
The fixed agent dictionary
"""
generic_allowed_models = {"gpt-4o", "claude-opus-4-6"}
allowed_models = {"gpt-4o", "claude-opus-4-6"}
# Create a mapping of block_id to block for quick lookup
block_map = {block.get("id"): block for block in blocks}
@@ -872,36 +868,20 @@ class AgentFixer:
input_default = node.get("input_default", {})
current_model = input_default.get("model")
# Determine allowed models and default from the block's schema.
# Blocks with a block-specific enum on the model field (e.g.
# PerplexityBlock) use their own enum values; others use the
# generic set.
model_schema = (
block.get("inputSchema", {}).get("properties", {}).get("model", {})
)
block_model_enum = model_schema.get("enum")
if block_model_enum:
allowed_models = set(block_model_enum)
fallback_model = model_schema.get("default", block_model_enum[0])
else:
allowed_models = generic_allowed_models
fallback_model = default_model
if current_model not in allowed_models:
block_name = block.get("name", "Unknown AI Block")
if current_model is None:
self.add_fix_log(
f"Added model parameter '{fallback_model}' to AI "
f"Added model parameter '{default_model}' to AI "
f"block node {node_id} ({block_name})"
)
else:
self.add_fix_log(
f"Replaced unsupported model '{current_model}' "
f"with '{fallback_model}' on AI block node "
f"with '{default_model}' on AI block node "
f"{node_id} ({block_name})"
)
input_default["model"] = fallback_model
input_default["model"] = default_model
node["input_default"] = input_default
fixed_count += 1

View File

@@ -475,111 +475,6 @@ class TestFixAiModelParameter:
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
def test_block_specific_enum_uses_block_default(self):
"""Blocks with their own model enum (e.g. PerplexityBlock) should use
the block's allowed models and default, not the generic ones."""
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "gpt-5.2-2025-12-11"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"name": "PerplexityBlock",
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {
"model": {
"type": "string",
"enum": [
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
"default": "perplexity/sonar",
}
},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
def test_block_specific_enum_valid_model_unchanged(self):
"""A valid block-specific model should not be replaced."""
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "perplexity/sonar-pro"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"name": "PerplexityBlock",
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {
"model": {
"type": "string",
"enum": [
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
"default": "perplexity/sonar",
}
},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar-pro"
def test_block_specific_enum_missing_model_gets_block_default(self):
"""Missing model on a block with enum should use the block's default."""
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id, input_default={})
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"name": "PerplexityBlock",
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {
"model": {
"type": "string",
"enum": [
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
"default": "perplexity/sonar",
}
},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
class TestFixAgentExecutorBlocks:
"""Tests for fix_agent_executor_blocks."""

View File

@@ -21,11 +21,9 @@ Lifecycle
Cost control
------------
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
(default: ``"pause"``) and ``auto_resume`` (default: ``True``). The explicit
per-turn ``pause_sandbox()`` call is the primary mechanism; the lifecycle
timeout is a safety net (default: 5 min). ``auto_resume`` ensures that paused
sandboxes wake transparently on SDK activity, making the aggressive safety-net
timeout safe. Paused sandboxes are free.
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
free.
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
@@ -42,7 +40,6 @@ import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox
from e2b.sandbox.sandbox_api import SandboxLifecycle
from backend.data.redis_client import get_redis_async
@@ -119,10 +116,9 @@ async def get_or_create_sandbox(
removes the need for a separate lock key.
*timeout* controls how long the e2b sandbox may run continuously before
the ``on_timeout`` lifecycle rule fires (default: 5 min).
the ``on_timeout`` lifecycle rule fires (default: 3 h).
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
or ``"kill"``. When ``"pause"``, ``auto_resume`` is enabled so paused
sandboxes wake transparently on SDK activity.
or ``"kill"``.
"""
redis = await get_redis_async()
key = _sandbox_key(session_id)
@@ -160,15 +156,11 @@ async def get_or_create_sandbox(
# We hold the slot — create the sandbox.
try:
lifecycle = SandboxLifecycle(
on_timeout=on_timeout,
auto_resume=on_timeout == "pause",
)
sandbox = await AsyncSandbox.create(
template=template,
api_key=api_key,
timeout=timeout,
lifecycle=lifecycle,
lifecycle={"on_timeout": on_timeout},
)
try:
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)

View File

@@ -157,17 +157,14 @@ class TestGetOrCreateSandbox:
assert result is new_sb
mock_cls.create.assert_awaited_once()
# Verify lifecycle: pause + auto_resume enabled
# Verify lifecycle param is set
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {
"on_timeout": "pause",
"auto_resume": True,
}
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
# sandbox_id should be saved to Redis
redis.set.assert_awaited()
def test_create_with_on_timeout_kill(self):
"""on_timeout='kill' disables auto_resume automatically."""
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
@@ -182,10 +179,7 @@ class TestGetOrCreateSandbox:
)
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {
"on_timeout": "kill",
"auto_resume": False,
}
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
def test_create_failure_releases_slot(self):
"""If sandbox creation fails, the Redis creation slot is deleted."""

View File

@@ -8,16 +8,11 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema
from backend.util.exceptions import BlockError
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
from .utils import match_credentials_to_requirements
@@ -25,26 +20,6 @@ from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().get_credits(user_id)
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if not db.is_connected():
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
@@ -136,23 +111,6 @@ async def execute_block(
session_id=session_id,
)
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Pre-execution credit check
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
if has_cost:
balance = await _get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
@@ -161,37 +119,6 @@ async def execute_block(
):
outputs[output_name].append(output_data)
# Charge credits for block execution
if has_cost:
try:
await _spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except InsufficientBalanceError:
logger.warning(
"Post-exec credit charge failed for block %s (cost=%d)",
block.name,
cost,
)
return ErrorResponse(
message=(
f"Insufficient credits to complete '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
@@ -202,16 +129,16 @@ async def execute_block(
)
except BlockError as e:
logger.warning("Block execution failed: %s", e)
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error("Unexpected error executing block: %s", e, exc_info=True)
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message="An unexpected error occurred while executing the block",
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -1,506 +0,0 @@
"""Tests for execute_block — credit charging and type coercion."""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
mock_spend = AsyncMock()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=100,
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=mock_spend,
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_spend.assert_awaited_once()
call_kwargs = mock_spend.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=5, # balance < cost (10)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
) as mock_get_credits,
patch(
"backend.copilot.tools.helpers._spend_credits",
) as mock_spend_credits,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_get_credits.assert_not_awaited()
mock_spend_credits.assert_not_awaited()
async def test_returns_error_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, return ErrorResponse."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=15, # passes pre-check
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
), # fails during actual charge (race with concurrent spend)
),
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
"""Create a mock input_schema with model_fields matching the given annotations."""
schema = MagicMock()
model_fields = {}
for name, ann in annotations.items():
field = MagicMock()
field.annotation = ann
model_fields[name] = field
schema.model_fields = model_fields
return schema
def _make_coerce_block(
block_id: str,
name: str,
annotations: dict[str, Any],
outputs: dict[str, list[Any]] | None = None,
) -> MagicMock:
"""Create a mock block with typed annotations and a simple execute method."""
block = MagicMock()
block.id = block_id
block.name = name
block.input_schema = _make_block_schema(annotations)
captured_inputs: dict[str, Any] = {}
async def mock_execute(input_data: dict, **_kwargs: Any):
captured_inputs.update(input_data)
for output_name, values in (outputs or {"result": ["ok"]}).items():
for v in values:
yield output_name, v
block.execute = mock_execute
block._captured_inputs = captured_inputs
return block
_TEST_SESSION_ID = "test-session-coerce"
_TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_coerce_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="sheets-write",
input_data={
"values": '[["Name","Score"],["Alice","90"],["Bob","85"]]',
"spreadsheet_id": "abc123",
},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert block._captured_inputs["values"] == [
["Name", "Score"],
["Alice", "90"],
["Bob", "85"],
]
assert isinstance(block._captured_inputs["values"], list)
assert isinstance(block._captured_inputs["values"][0], list)
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_coerce_block(
"list-block",
"List Block",
{"items": list[str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="list-block",
input_data={"items": '["a","b","c"]'},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-2",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["items"] == ["a", "b", "c"]
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_coerce_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="dict-block",
input_data={"config": '{"key": "value", "foo": "bar"}'},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-3",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["config"] == {"key": "value", "foo": "bar"}
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_coerce_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
)
original_values = [["a", "b"], ["c", "d"]]
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="pass-through",
input_data={"values": original_values, "name": "test"},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-4",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["values"] == original_values
assert block._captured_inputs["name"] == "test"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_coerce_block(
"int-block",
"Int Block",
{"count": int},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="int-block",
input_data={"count": "42"},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-5",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["count"] == 42
assert isinstance(block._captured_inputs["count"], int)
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_coerce_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="optional-block",
input_data={"label": "test"},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-6",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert "data" not in block._captured_inputs
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_coerce_block(
"union-block",
"Union Block",
{"content": str | list[str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="union-block",
input_data={"content": ["a", "b"]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-7",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["content"] == ["a", "b"]
assert isinstance(block._captured_inputs["content"], list)
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_coerce_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},
)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
):
response = await execute_block(
block=block,
block_id="inner-coerce",
input_data={"values": [1, 2, 3]},
user_id=_TEST_USER_ID,
session_id=_TEST_SESSION_ID,
node_exec_id="exec-8",
matched_credentials={},
)
assert isinstance(response, BlockOutputResponse)
assert block._captured_inputs["values"] == ["1", "2", "3"]
assert all(isinstance(v, str) for v in block._captured_inputs["values"])

View File

@@ -34,11 +34,6 @@ logger = logging.getLogger(__name__)
_AUTH_STATUS_CODES = {401, 403}
def _service_name(host: str) -> str:
"""Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev''sentry.dev'"""
return host[4:] if host.startswith("mcp.") else host
class RunMCPToolTool(BaseTool):
"""
Tool for discovering and executing tools on any MCP server.
@@ -308,8 +303,8 @@ class RunMCPToolTool(BaseTool):
)
return ErrorResponse(
message=(
f"Unable to connect to {_service_name(server_host(server_url))} "
" no credentials configured."
f"The MCP server at {server_host(server_url)} requires authentication, "
"but no credential configuration was found."
),
session_id=session_id,
)
@@ -317,13 +312,15 @@ class RunMCPToolTool(BaseTool):
missing_creds_list = list(missing_creds_dict.values())
host = server_host(server_url)
service = _service_name(host)
return SetupRequirementsResponse(
message=(f"To continue, sign in to {service} and approve access."),
message=(
f"The MCP server at {host} requires authentication. "
"Please connect your credentials to continue."
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=server_url,
agent_name=service,
agent_name=f"MCP: {host}",
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_creds_dict,

View File

@@ -756,4 +756,4 @@ async def test_build_setup_requirements_returns_setup_response():
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_id == _SERVER_URL
assert "sign in" in result.message.lower()
assert "authentication" in result.message.lower()

View File

@@ -100,31 +100,19 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO_PREVIEW: 4,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_1_PRO_PREVIEW: 5,
LlmModel.GEMINI_3_FLASH_PREVIEW: 2,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.MISTRAL_LARGE_3: 2,
LlmModel.MISTRAL_MEDIUM_3_1: 2,
LlmModel.MISTRAL_SMALL_3_2: 1,
LlmModel.CODESTRAL: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.COHERE_COMMAND_A_03_2025: 3,
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: 3,
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: 6,
LlmModel.COHERE_COMMAND_A_VISION_07_2025: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
@@ -132,7 +120,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.MICROSOFT_PHI_4: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
@@ -140,7 +127,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_3: 3,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,

View File

@@ -512,10 +512,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -1,750 +0,0 @@
import asyncio
import csv
import io
import logging
import os
import re
import socket
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, Optional
from uuid import uuid4
import prisma.enums
import prisma.models
import prisma.types
from prisma.errors import UniqueViolationError
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
from backend.data.db import transaction
from backend.data.model import User
from backend.data.redis_client import get_redis_async
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
)
from backend.data.user import get_user_by_email, get_user_by_id
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.json import SafeJson
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
_tally_seed_tasks: dict[str, asyncio.Task] = {}
_TALLY_STALE_SECONDS = 300
_MAX_TALLY_ERROR_LENGTH = 200
_email_adapter = TypeAdapter(EmailStr)
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
MAX_BULK_INVITE_ROWS = 500
class InvitedUserRecord(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
created_at=invited_user.createdAt,
updated_at=invited_user.updatedAt,
)
class BulkInvitedUserRowResult(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserRecord] = None
class BulkInvitedUsersResult(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResult]
@dataclass(frozen=True)
class _ParsedInviteRow:
row_number: int
email: str
name: Optional[str]
def normalize_email(email: str) -> str:
return email.strip().lower()
def _normalize_name(name: Optional[str]) -> Optional[str]:
if name is None:
return None
normalized = name.strip()
return normalized or None
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
if preferred_name:
return preferred_name
local_part = email.split("@", 1)[0].strip()
return local_part or "user"
def _sanitize_username_base(email: str) -> str:
local_part = email.split("@", 1)[0].lower()
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
sanitized = sanitized.strip("-")
return sanitized[:40] or "user"
async def _generate_unique_profile_username(email: str, tx) -> str:
base = _sanitize_username_base(email)
for _ in range(2):
candidate = f"{base}-{uuid4().hex[:6]}"
existing = await prisma.models.Profile.prisma(tx).find_unique(
where={"username": candidate}
)
if existing is None:
return candidate
raise RuntimeError(f"Unable to generate unique username for {email}")
async def _ensure_default_profile(
user_id: str,
email: str,
preferred_name: Optional[str],
tx,
) -> None:
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
where={"userId": user_id}
)
if existing_profile is not None:
return
username = await _generate_unique_profile_username(email, tx)
await prisma.models.Profile.prisma(tx).create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name=_default_profile_name(email, preferred_name),
username=username,
description="I'm new here",
links=[],
avatarUrl="",
)
)
async def _ensure_default_onboarding(user_id: str, tx) -> None:
await prisma.models.UserOnboarding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def _apply_tally_understanding(
user_id: str,
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
if not isinstance(invited_user.tallyUnderstanding, dict):
return
try:
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
except Exception:
logger.warning(
"Malformed tallyUnderstanding for invited user %s; skipping",
invited_user.id,
exc_info=True,
)
return
payload = merge_business_understanding_data({}, input_data)
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(payload)},
"update": {"data": SafeJson(payload)},
},
)
async def list_invited_users(
page: int = 1,
page_size: int = 50,
) -> tuple[list[InvitedUserRecord], int]:
total = await prisma.models.InvitedUser.prisma().count()
invited_users = await prisma.models.InvitedUser.prisma().find_many(
order={"createdAt": "desc"},
skip=(page - 1) * page_size,
take=page_size,
)
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:
normalized_email = normalize_email(email)
normalized_name = _normalize_name(name)
existing_user = await prisma.models.User.prisma().find_unique(
where={"email": normalized_email}
)
if existing_user is not None:
raise PreconditionFailed("An active user with this email already exists")
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if existing_invited_user is not None:
raise PreconditionFailed("An invited user with this email already exists")
try:
invited_user = await prisma.models.InvitedUser.prisma().create(
data={
"email": normalized_email,
"name": normalized_name,
"status": prisma.enums.InvitedUserStatus.INVITED,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
}
)
except UniqueViolationError:
raise PreconditionFailed("An invited user with this email already exists")
schedule_invited_user_tally_precompute(invited_user.id)
return InvitedUserRecord.from_db(invited_user)
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
raise PreconditionFailed("Claimed invited users cannot be revoked")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return InvitedUserRecord.from_db(invited_user)
revoked_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
)
if revoked_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
return InvitedUserRecord.from_db(revoked_user)
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
refreshed_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": None,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
"tallyComputedAt": None,
"tallyError": None,
},
)
if refreshed_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
schedule_invited_user_tally_precompute(invited_user_id)
return InvitedUserRecord.from_db(refreshed_user)
def _decode_bulk_invite_file(content: bytes) -> str:
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
raise ValueError("Invite file exceeds the maximum size of 1 MB")
try:
return content.decode("utf-8-sig")
except UnicodeDecodeError as exc:
raise ValueError("Invite file must be UTF-8 encoded") from exc
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
indexed_rows: list[tuple[int, list[str]]] = []
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
normalized_row = [cell.strip() for cell in row]
if any(normalized_row):
indexed_rows.append((row_number, normalized_row))
if not indexed_rows:
return []
header = [cell.lower() for cell in indexed_rows[0][1]]
has_header = "email" in header
email_index = header.index("email") if has_header else 0
name_index: Optional[int] = (
header.index("name")
if has_header and "name" in header
else (1 if not has_header else None)
)
data_rows = indexed_rows[1:] if has_header else indexed_rows
parsed_rows: list[_ParsedInviteRow] = []
for row_number, row in data_rows:
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
break
email = row[email_index].strip() if len(row) > email_index else ""
name = (
row[name_index].strip()
if name_index is not None and len(row) > name_index
else ""
)
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=email,
name=name or None,
)
)
return parsed_rows
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
parsed_rows: list[_ParsedInviteRow] = []
for row_number, raw_line in enumerate(text.splitlines(), start=1):
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
break
line = raw_line.strip()
if not line or line.startswith("#"):
continue
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=line,
name=None,
)
)
return parsed_rows
def _parse_bulk_invite_file(
filename: Optional[str],
content: bytes,
) -> list[_ParsedInviteRow]:
text = _decode_bulk_invite_file(content)
file_name = filename.lower() if filename else ""
parsed_rows = (
_parse_bulk_invite_csv(text)
if file_name.endswith(".csv")
else _parse_bulk_invite_text(text)
)
if not parsed_rows:
raise ValueError("Invite file did not contain any emails")
return parsed_rows
async def bulk_create_invited_users_from_file(
filename: Optional[str],
content: bytes,
) -> BulkInvitedUsersResult:
parsed_rows = _parse_bulk_invite_file(filename, content)
created_count = 0
skipped_count = 0
error_count = 0
results: list[BulkInvitedUserRowResult] = []
seen_emails: set[str] = set()
for row in parsed_rows:
row_name = _normalize_name(row.name)
try:
validated_email = _email_adapter.validate_python(row.email)
except ValidationError:
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=row.email or None,
name=row_name,
status="ERROR",
message="Invalid email address",
)
)
continue
normalized_email = normalize_email(str(validated_email))
if normalized_email in seen_emails:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message="Duplicate email in upload file",
)
)
continue
seen_emails.add(normalized_email)
try:
invited_user = await create_invited_user(normalized_email, row_name)
except PreconditionFailed as exc:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message=str(exc),
)
)
except Exception:
masked = mask_email(normalized_email)
logger.exception(
"Failed to create bulk invite for row %s (%s)",
row.row_number,
masked,
)
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="ERROR",
message="Unexpected error creating invite",
)
)
else:
created_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="CREATED",
message="Invite created",
invited_user=invited_user,
)
)
return BulkInvitedUsersResult(
created_count=created_count,
skipped_count=skipped_count,
error_count=error_count,
results=results,
)
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
return
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return
try:
r = await get_redis_async()
except Exception:
r = None
lock: AsyncClusterLock | None = None
if r is not None:
lock = AsyncClusterLock(
redis=r,
key=f"tally_seed:{invited_user_id}",
owner_id=_WORKER_ID,
timeout=_TALLY_STALE_SECONDS,
)
current_owner = await lock.try_acquire()
if current_owner is None:
logger.warn("Redis unvailable for tally lock - skipping tally enrichement")
return
elif current_owner != _WORKER_ID:
logger.debug(
"Tally seed for %s already locked by %s, skipping",
invited_user_id,
current_owner,
)
return
if (
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
and invited_user.updatedAt is not None
):
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
if age < _TALLY_STALE_SECONDS:
logger.debug(
"Tally task for %s still RUNNING (age=%ds), skipping",
invited_user_id,
int(age),
)
return
logger.info(
"Tally task for %s is stale (age=%ds), re-running",
invited_user_id,
int(age),
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
"tallyError": None,
},
)
try:
input_data = await get_business_understanding_input_from_tally(
invited_user.email,
require_api_key=True,
)
payload = (
SafeJson(input_data.model_dump(exclude_none=True))
if input_data is not None
else None
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": payload,
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
"tallyComputedAt": datetime.now(timezone.utc),
"tallyError": None,
},
)
except Exception as exc:
logger.exception(
"Failed to compute Tally understanding for invited user %s",
invited_user_id,
)
sanitized_error = re.sub(
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
)[:_MAX_TALLY_ERROR_LENGTH]
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
"tallyError": sanitized_error,
},
)
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
existing = _tally_seed_tasks.get(invited_user_id)
if existing is not None and not existing.done():
logger.debug("Tally task already running for %s, skipping", invited_user_id)
return
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
_tally_seed_tasks[invited_user_id] = task
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
if _tally_seed_tasks.get(_id) is t:
del _tally_seed_tasks[_id]
task.add_done_callback(_on_done)
async def _open_signup_create_user(
auth_user_id: str,
normalized_email: str,
metadata_name: Optional[str],
) -> User:
"""Create a user without requiring an invite (open signup mode)."""
preferred_name = _normalize_name(metadata_name)
try:
async with transaction() as tx:
user = await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await _ensure_default_profile(
auth_user_id, normalized_email, preferred_name, tx
)
await _ensure_default_onboarding(auth_user_id, tx)
except UniqueViolationError:
existing = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if existing is not None:
return User.from_db(existing)
raise
return User.from_db(user)
# TODO: We need to change this functions logic before going live
async def get_or_activate_user(user_data: dict) -> User:
auth_user_id = user_data.get("sub")
if not auth_user_id:
raise NotAuthorizedError("User ID not found in token")
auth_email = user_data.get("email")
if not auth_email:
raise NotAuthorizedError("Email not found in token")
normalized_email = normalize_email(auth_email)
user_metadata = user_data.get("user_metadata")
metadata_name = (
user_metadata.get("name") if isinstance(user_metadata, dict) else None
)
existing_user = None
try:
existing_user = await get_user_by_id(auth_user_id)
except ValueError:
existing_user = None
except Exception:
logger.exception("Error on get user by id during tally enrichment process")
raise
if existing_user is not None:
return existing_user
if not _settings.config.enable_invite_gate or normalized_email.endswith("@agpt.co"):
return await _open_signup_create_user(
auth_user_id, normalized_email, metadata_name
)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
try:
async with transaction() as tx:
current_user = await prisma.models.User.prisma(tx).find_unique(
where={"id": auth_user_id}
)
if current_user is not None:
return User.from_db(current_user)
current_invited_user = await prisma.models.InvitedUser.prisma(
tx
).find_unique(where={"email": normalized_email})
if current_invited_user is None:
raise NotAuthorizedError(
"Your email is not allowed to access the platform"
)
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
if current_invited_user.authUserId not in (None, auth_user_id):
raise NotAuthorizedError("Your invitation has already been claimed")
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await prisma.models.InvitedUser.prisma(tx).update(
where={"id": current_invited_user.id},
data={
"status": prisma.enums.InvitedUserStatus.CLAIMED,
"authUserId": auth_user_id,
},
)
await _ensure_default_profile(
auth_user_id,
normalized_email,
preferred_name,
tx,
)
await _ensure_default_onboarding(auth_user_id, tx)
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
except UniqueViolationError:
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
already_created = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if already_created is not None:
return User.from_db(already_created)
raise RuntimeError(
f"UniqueViolationError during activation but user {auth_user_id} not found"
)
get_user_by_id.cache_delete(auth_user_id)
get_user_by_email.cache_delete(normalized_email)
activated_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if activated_user is None:
raise RuntimeError(
f"Activated user {auth_user_id} was not found after creation"
)
return User.from_db(activated_user)

View File

@@ -1,335 +0,0 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import cast
from unittest.mock import AsyncMock, Mock
import prisma.enums
import prisma.models
import pytest
import pytest_mock
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
from .invited_user import (
InvitedUserRecord,
bulk_create_invited_users_from_file,
create_invited_user,
get_or_activate_user,
retry_invited_user_tally,
)
def _invited_user_db_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="invite-1",
email="invited@example.com",
status=status,
authUserId=None,
name="Invited User",
tallyUnderstanding=tally_understanding,
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
tallyComputedAt=None,
tallyError=None,
createdAt=now,
updatedAt=now,
)
def _invited_user_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
return InvitedUserRecord.from_db(
cast(
prisma.models.InvitedUser,
_invited_user_db_record(
status=status,
tally_understanding=tally_understanding,
),
)
)
def _user_db_record():
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="auth-user-1",
email="invited@example.com",
emailVerified=True,
name="Invited User",
createdAt=now,
updatedAt=now,
metadata={},
integrations="",
stripeCustomerId=None,
topUpConfig=None,
maxEmailsPerDay=3,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="not-set",
)
@pytest.mark.asyncio
async def test_create_invited_user_rejects_existing_active_user(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock()
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(PreconditionFailed):
await create_invited_user("Invited@example.com")
@pytest.mark.asyncio
async def test_create_invited_user_schedules_tally_seed(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await create_invited_user("Invited@example.com", "Invited User")
assert invited_user.email == "invited@example.com"
invited_user_repo.create.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_retry_invited_user_tally_resets_state_and_schedules(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await retry_invited_user_tally("invite-1")
assert invited_user.id == "invite-1"
invited_user_repo.update.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_get_or_activate_user_requires_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
mock_get_user_by_id.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_id",
mock_get_user_by_id,
)
mocker.patch(
"backend.data.invited_user._settings.config.enable_invite_gate",
True,
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(NotAuthorizedError):
await get_or_activate_user(
{"sub": "auth-user-1", "email": "invited@example.com"}
)
@pytest.mark.asyncio
async def test_get_or_activate_user_creates_user_from_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
tx = object()
invited_user = _invited_user_db_record(
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
)
created_user = _user_db_record()
outside_user_repo = Mock()
# Only called once at post-transaction verification (line 741);
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
inside_user_repo = Mock()
inside_user_repo.find_unique = AsyncMock(return_value=None)
inside_user_repo.create = AsyncMock(return_value=created_user)
outside_invited_repo = Mock()
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo = Mock()
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo.update = AsyncMock(return_value=invited_user)
def user_prisma(client=None):
return inside_user_repo if client is tx else outside_user_repo
def invited_user_prisma(client=None):
return inside_invited_repo if client is tx else outside_invited_repo
@asynccontextmanager
async def fake_transaction():
yield tx
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
# not prisma.models.User.prisma().find_unique which we mock above.
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
mock_get_user_by_id.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_id",
mock_get_user_by_id,
)
mock_get_user_by_email = AsyncMock()
mock_get_user_by_email.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_email",
mock_get_user_by_email,
)
ensure_profile = mocker.patch(
"backend.data.invited_user._ensure_default_profile",
AsyncMock(),
)
ensure_onboarding = mocker.patch(
"backend.data.invited_user._ensure_default_onboarding",
AsyncMock(),
)
apply_tally = mocker.patch(
"backend.data.invited_user._apply_tally_understanding",
AsyncMock(),
)
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
side_effect=invited_user_prisma,
)
user = await get_or_activate_user(
{
"sub": "auth-user-1",
"email": "Invited@example.com",
"user_metadata": {"name": "Invited User"},
}
)
assert user.id == "auth-user-1"
inside_user_repo.create.assert_awaited_once()
inside_invited_repo.update.assert_awaited_once()
ensure_profile.assert_awaited_once()
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
@pytest.mark.asyncio
async def test_bulk_create_invited_users_from_text_file(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_record(),
_invited_user_record(),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.txt",
b"Invited@example.com\nsecond@example.com\n",
)
assert result.created_count == 2
assert result.skipped_count == 0
assert result.error_count == 0
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
assert create_invited.await_count == 2
@pytest.mark.asyncio
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_record(),
PreconditionFailed("An invited user with this email already exists"),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.csv",
(
"email,name\n"
"valid@example.com,Valid User\n"
"not-an-email,Bad Row\n"
"valid@example.com,Duplicate In File\n"
"existing@example.com,Existing User\n"
).encode("utf-8"),
)
assert result.created_count == 1
assert result.skipped_count == 2
assert result.error_count == 1
assert [row.status for row in result.results] == [
"CREATED",
"ERROR",
"SKIPPED",
"SKIPPED",
]
assert create_invited.await_count == 2

View File

@@ -8,8 +8,6 @@ from backend.api.model import NotificationPayload
from backend.data.event_bus import AsyncRedisEventBus
from backend.util.settings import Settings
_settings = Settings()
class NotificationEvent(BaseModel):
"""Generic notification event destined for websocket delivery."""
@@ -28,7 +26,7 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
@property
def event_bus_name(self) -> str:
return _settings.config.notification_event_bus_name
return Settings().config.notification_event_bus_name
async def publish(self, event: NotificationEvent) -> None:
await self.publish_event(event, event.user_id)

View File

@@ -41,7 +41,7 @@ _MAX_PAGES = 100
_LLM_TIMEOUT = 30
def mask_email(email: str) -> str:
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
try:
local, domain = email.rsplit("@", 1)
@@ -196,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
Returns (email_index, questions).
"""
client = _make_tally_client(_settings.secrets.tally_api_key)
settings = Settings()
client = _make_tally_client(settings.secrets.tally_api_key)
redis = await get_redis_async()
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
@@ -331,9 +332,6 @@ Fields:
- current_software (list of strings): software/tools currently used
- existing_automation (list of strings): existing automations
- additional_notes (string): any additional context
- suggested_prompts (list of 5 strings): short action prompts (each under 20 words) that would help \
this person get started with automating their work. Should be specific to their industry, role, and \
pain points; actionable and conversational in tone; focused on automation opportunities.
Form data:
"""
@@ -341,21 +339,21 @@ Form data:
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
async def extract_business_understanding_from_tally(
async def extract_business_understanding(
formatted_text: str,
) -> BusinessUnderstandingInput:
"""
Use an LLM to extract structured business understanding from form text.
"""Use an LLM to extract structured business understanding from form text.
Raises on timeout or unparseable response so the caller can handle it.
"""
api_key = _settings.secrets.open_router_api_key
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
try:
response = await asyncio.wait_for(
client.chat.completions.create(
model=_settings.config.tally_extraction_llm_model,
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
@@ -380,57 +378,9 @@ async def extract_business_understanding_from_tally(
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
# Validate suggested_prompts: filter >20 words, keep top 3
raw_prompts = cleaned.get("suggested_prompts", [])
if isinstance(raw_prompts, list):
valid = [
p.strip()
for p in raw_prompts
if isinstance(p, str) and len(p.strip().split()) <= 20
]
# This will keep up to 3 suggestions
short_prompts = valid[:3] if valid else None
if short_prompts:
cleaned["suggested_prompts"] = short_prompts
else:
# We dont want to add a None value suggested_prompts field
cleaned.pop("suggested_prompts", None)
else:
# suggested_prompts must be a list - removing it as its not here
cleaned.pop("suggested_prompts", None)
return BusinessUnderstandingInput(**cleaned)
async def get_business_understanding_input_from_tally(
email: str,
*,
require_api_key: bool = False,
) -> Optional[BusinessUnderstandingInput]:
if not _settings.secrets.tally_api_key:
if require_api_key:
raise RuntimeError("Tally API key is not configured")
logger.debug("Tally: no API key configured, skipping")
return None
masked = mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return None
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return None
return await extract_business_understanding_from_tally(formatted)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
@@ -445,10 +395,30 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
)
return
understanding_input = await get_business_understanding_input_from_tally(email)
if understanding_input is None:
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")

View File

@@ -12,11 +12,11 @@ from backend.data.tally import (
_build_email_index,
_format_answer,
_make_tally_client,
_mask_email,
_refresh_cache,
extract_business_understanding_from_tally,
extract_business_understanding,
find_submission_by_email,
format_submission_for_llm,
mask_email,
populate_understanding_from_tally,
)
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
],
}
mock_input = MagicMock()
mock_input.suggested_prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
with (
patch(
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding_from_tally",
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
return_value=mock_input,
) as mock_extract,
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding_from_tally",
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError(),
),
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
def test_mask_email():
assert mask_email("alice@example.com") == "a***e@example.com"
assert mask_email("ab@example.com") == "a***@example.com"
assert mask_email("a@example.com") == "a***@example.com"
assert _mask_email("alice@example.com") == "a***e@example.com"
assert _mask_email("ab@example.com") == "a***@example.com"
assert _mask_email("a@example.com") == "a***@example.com"
def test_mask_email_invalid():
assert mask_email("no-at-sign") == "***"
assert _mask_email("no-at-sign") == "***"
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
@@ -394,11 +393,11 @@ def test_extraction_prompt_no_format_placeholders():
assert single_braces == [], f"Found format placeholders: {single_braces}"
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
# ── extract_business_understanding ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_success():
async def test_extract_business_understanding_success():
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
@@ -407,13 +406,6 @@ async def test_extract_business_understanding_from_tally_success():
"business_name": "Acme Corp",
"industry": "Technology",
"pain_points": ["manual reporting"],
"suggested_prompts": [
"Automate weekly reports",
"Set up invoice processing",
"Create a customer onboarding flow",
"Track project deadlines automatically",
"Send follow-up emails after meetings",
],
}
)
mock_response = MagicMock()
@@ -423,56 +415,16 @@ async def test_extract_business_understanding_from_tally_success():
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name == "Acme Corp"
assert result.industry == "Technology"
assert result.pain_points == ["manual reporting"]
# suggested_prompts validated and sliced to top 3
assert result.suggested_prompts == [
"Automate weekly reports",
"Set up invoice processing",
"Create a customer onboarding flow",
]
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_filters_long_prompts():
"""Prompts exceeding 20 words are excluded and only top 3 are kept."""
long_prompt = " ".join(["word"] * 21)
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"suggested_prompts": [
long_prompt,
"Short prompt one",
long_prompt,
"Short prompt two",
"Short prompt three",
"Short prompt four",
],
}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
assert result.suggested_prompts == [
"Short prompt one",
"Short prompt two",
"Short prompt three",
]
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_filters_nulls():
async def test_extract_business_understanding_filters_nulls():
"""Null values from LLM should be excluded from the result."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
@@ -485,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name is None
@@ -493,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_invalid_json():
async def test_extract_business_understanding_invalid_json():
"""Invalid JSON from LLM should raise JSONDecodeError."""
mock_choice = MagicMock()
mock_choice.message.content = "not valid json {"
@@ -507,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
pytest.raises(json.JSONDecodeError),
):
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
await extract_business_understanding("Q: Name?\nA: Alice")
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_timeout():
async def test_extract_business_understanding_timeout():
"""LLM timeout should propagate as asyncio.TimeoutError."""
mock_client = AsyncMock()
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
@@ -521,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
pytest.raises(asyncio.TimeoutError),
):
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
await extract_business_understanding("Q: Name?\nA: Alice")
# ── _refresh_cache ───────────────────────────────────────────────────────────
@@ -540,7 +492,7 @@ async def test_refresh_cache_full_fetch():
submissions = SAMPLE_SUBMISSIONS
with (
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
@@ -588,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
with (
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,

View File

@@ -86,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
None, description="Any additional context"
)
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: Optional[list[str]] = pydantic.Field(
None, description="LLM-generated suggested prompts based on business context"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
@@ -127,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
# Additional context
additional_notes: Optional[str] = None
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: list[str] = pydantic.Field(default_factory=list)
@classmethod
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
@@ -157,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
additional_notes=business.get("additional_notes"),
suggested_prompts=_json_to_list(data.get("suggested_prompts")),
)
@@ -175,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
return merged
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> dict[str, Any]:
merged_data = dict(existing_data)
merged_business: dict[str, Any] = {}
if isinstance(merged_data.get("business"), dict):
merged_business = dict(merged_data["business"])
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
if input_data.user_name is not None:
merged_data["name"] = input_data.user_name
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
merged_business[field] = value
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(merged_business.get(field))
merged_list = _merge_lists(existing_list, value)
merged_business[field] = merged_list
merged_business["version"] = 1
merged_data["business"] = merged_business
# suggested_prompts lives at the top level (not under `business`) because
# it's a UI-only artifact consumed by the frontend, not business understanding
# data. The `business` sub-dict feeds the system prompt.
if input_data.suggested_prompts is not None:
merged_data["suggested_prompts"] = input_data.suggested_prompts
return merged_data
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
@@ -310,18 +245,63 @@ async def upsert_business_understanding(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
merged_data = merge_business_understanding_data(existing_data, input_data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
},
)

View File

@@ -1,102 +0,0 @@
"""Tests for business understanding merge and format logic."""
from datetime import datetime, timezone
from typing import Any
from backend.data.understanding import (
BusinessUnderstanding,
BusinessUnderstandingInput,
format_understanding_for_prompt,
merge_business_understanding_data,
)
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
"""Create a BusinessUnderstandingInput with only the specified fields."""
return BusinessUnderstandingInput.model_validate(kwargs)
# ─── merge_business_understanding_data: suggested_prompts ─────────────
def test_merge_suggested_prompts_overwrites_existing():
"""New suggested_prompts should fully replace existing ones (not append)."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
}
input_data = _make_input(
suggested_prompts=["New prompt A", "New prompt B", "New prompt C"],
)
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == [
"New prompt A",
"New prompt B",
"New prompt C",
]
def test_merge_suggested_prompts_none_preserves_existing():
"""When input has suggested_prompts=None, existing prompts are preserved."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Keep me"],
}
input_data = _make_input(industry="Finance")
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == ["Keep me"]
assert result["business"]["industry"] == "Finance"
def test_merge_suggested_prompts_added_to_empty_data():
"""Suggested prompts are set at top level even when starting from empty data."""
existing: dict[str, Any] = {}
input_data = _make_input(suggested_prompts=["Prompt 1"])
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == ["Prompt 1"]
def test_merge_suggested_prompts_empty_list_overwrites():
"""An explicit empty list should overwrite existing prompts."""
existing: dict[str, Any] = {
"suggested_prompts": ["Old prompt"],
"business": {"version": 1},
}
input_data = _make_input(suggested_prompts=[])
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == []
# ─── format_understanding_for_prompt: excludes suggested_prompts ──────
def test_format_understanding_excludes_suggested_prompts():
"""suggested_prompts is UI-only and must NOT appear in the system prompt."""
understanding = BusinessUnderstanding(
id="test-id",
user_id="user-1",
created_at=datetime.now(tz=timezone.utc),
updated_at=datetime.now(tz=timezone.utc),
user_name="Alice",
industry="Technology",
suggested_prompts=["Automate reports", "Set up alerts", "Track KPIs"],
)
formatted = format_understanding_for_prompt(understanding)
assert "Alice" in formatted
assert "Technology" in formatted
assert "suggested_prompts" not in formatted
assert "Automate reports" not in formatted
assert "Set up alerts" not in formatted
assert "Track KPIs" not in formatted

View File

@@ -46,7 +46,7 @@ from backend.util.exceptions import (
)
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
from backend.util.settings import Config
from backend.util.type import coerce_inputs_to_schema
from backend.util.type import convert
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
@@ -213,8 +213,11 @@ def validate_exec(
if resolve_input:
data = merge_execution_input(data)
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(data, schema)
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
value = data.get(name)
if (value is not None) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):

View File

@@ -70,9 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
# Count tool result tokens
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
tool_call_tokens += _tok_len(item.get("content", ""), enc)
elif isinstance(item, dict) and item.get("type") == "text":
# Count text block tokens
tool_call_tokens += _tok_len(item.get("text", ""), enc)
elif isinstance(item, dict) and "content" in item:
# Other content types with content field
tool_call_tokens += _tok_len(item.get("content", ""), enc)
@@ -148,14 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
if len(ids) <= max_tok:
return text # nothing to do
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
mid = enc.encode("")
if max_tok < 3:
return enc.decode(mid)
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
@@ -403,7 +396,7 @@ def validate_and_remove_orphan_tool_responses(
if log_warning:
logger.warning(
"Removing %d orphan tool response(s): %s", len(orphan_ids), orphan_ids
f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
)
return _remove_orphan_tool_responses(messages, orphan_ids)
@@ -495,9 +488,8 @@ def _ensure_tool_pairs_intact(
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
"Could not find assistant messages for tool_call_ids: %s. "
"Removing orphan tool responses.",
orphan_tool_call_ids,
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
@@ -505,8 +497,8 @@ def _ensure_tool_pairs_intact(
if messages_to_prepend:
logger.info(
"Extended recent messages by %d to preserve tool_call/tool_response pairs",
len(messages_to_prepend),
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
@@ -694,15 +686,11 @@ async def compress_context(
msgs = [summary_msg] + recent_msgs
logger.info(
"Context summarized: %d -> %d tokens, summarized %d messages",
original_count,
total_tokens(),
messages_summarized,
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
)
except Exception as e:
logger.warning(
"Summarization failed, continuing with truncation: %s", e
)
logger.warning(f"Summarization failed, continuing with truncation: {e}")
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
@@ -740,12 +728,6 @@ async def compress_context(
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
# Count assistant messages to ensure we keep at least one
assistant_indices: set[int] = {
i
for i in range(len(msgs))
if msgs[i] is not None and msgs[i].get("role") == "assistant"
}
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
@@ -753,9 +735,6 @@ async def compress_context(
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
# Skip if this is the last remaining assistant message
if msg.get("role") == "assistant" and len(assistant_indices) <= 1:
continue
deletable.append(i)
if not deletable:
break

View File

@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
le=500,
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
)
tally_extraction_llm_model: str = Field(
default="openai/gpt-4o-mini",
description="OpenRouter model ID used for extracting business understanding from Tally form data",
)
ollama_host: str = Field(
default="localhost:11434",
description="Default Ollama host; exempted from SSRF checks.",
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=True,
description="If authentication is enabled or not",
)
enable_invite_gate: bool = Field(
default=True,
description="If the invite-only signup gate is enforced",
)
enable_credit: bool = Field(
default=False,
description="If user credit system is enabled or not",

View File

@@ -249,87 +249,6 @@ def convert(value: Any, target_type: Any) -> Any:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
def _value_satisfies_type(value: Any, target: Any) -> bool:
"""Check whether *value* already satisfies *target*, including inner elements.
For union types this checks each member; for generic container types it
recursively checks that inner elements satisfy the type args (e.g. every
element in a ``list[str]`` is a ``str``). Returns ``False`` when uncertain
so the caller falls through to :func:`convert`.
"""
# typing.Any cannot be used with isinstance(); treat as always satisfied.
if target is Any:
return True
origin = get_origin(target)
if origin is Union or origin is types.UnionType:
non_none = [a for a in get_args(target) if a is not type(None)]
return any(_value_satisfies_type(value, member) for member in non_none)
# Generic container type (e.g. list[str], dict[str, int])
if origin is not None:
# Guard: origin may not be a runtime type (e.g. Literal)
if not isinstance(origin, type):
return False
if not isinstance(value, origin):
return False
args = get_args(target)
if not args:
return True
# Check inner elements satisfy the type args
if _is_type_or_subclass(origin, list):
return all(_value_satisfies_type(v, args[0]) for v in value)
if _is_type_or_subclass(origin, dict) and len(args) >= 2:
return all(
_value_satisfies_type(k, args[0]) and _value_satisfies_type(v, args[1])
for k, v in value.items()
)
if (
_is_type_or_subclass(origin, set) or _is_type_or_subclass(origin, frozenset)
) and args:
return all(_value_satisfies_type(v, args[0]) for v in value)
if _is_type_or_subclass(origin, tuple):
# Homogeneous tuple[T, ...] — single type + Ellipsis
if len(args) == 2 and args[1] is Ellipsis:
return all(_value_satisfies_type(v, args[0]) for v in value)
# Heterogeneous tuple[T1, T2, ...] — positional types
if len(value) != len(args):
return False
return all(_value_satisfies_type(v, t) for v, t in zip(value, args))
# Unhandled generic origin — fall through to convert()
return False
# Simple type (e.g. str, int)
if isinstance(target, type):
return isinstance(value, target)
return False
def coerce_inputs_to_schema(data: dict[str, Any], schema: type) -> None:
"""Coerce *data* values in-place to match *schema*'s field types.
Uses ``model_fields`` (not ``__annotations__``) so inherited fields are
included. Skips coercion when the value already satisfies the target
type — in particular for union-typed fields where the value matches one
member but differs from the annotation object itself.
This is the single authoritative coercion step shared by the executor
(``validate_exec``) and the CoPilot (``execute_block``).
"""
for name, field_info in schema.model_fields.items():
value = data.get(name)
if value is None:
continue
target = field_info.annotation
if target is None:
continue
if _value_satisfies_type(value, target):
continue
data[name] = convert(value, target)
class FormattedStringType(str):
string_format: str

View File

@@ -1,8 +1,6 @@
from typing import Any, List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel
from backend.util.type import _value_satisfies_type, coerce_inputs_to_schema, convert
from backend.util.type import convert
def test_type_conversion():
@@ -48,343 +46,3 @@ def test_type_conversion():
# Test other empty list conversions
assert convert([], int) == 0 # len([]) = 0
assert convert([], Optional[int]) == 0
# ---------------------------------------------------------------------------
# _value_satisfies_type
# ---------------------------------------------------------------------------
class TestValueSatisfiesType:
# --- simple types ---
def test_simple_match(self):
assert _value_satisfies_type("hello", str) is True
assert _value_satisfies_type(42, int) is True
assert _value_satisfies_type(3.14, float) is True
assert _value_satisfies_type(True, bool) is True
def test_simple_mismatch(self):
assert _value_satisfies_type("hello", int) is False
assert _value_satisfies_type(42, str) is False
assert _value_satisfies_type([1, 2], str) is False
# --- Any ---
def test_any_always_satisfied(self):
assert _value_satisfies_type("hello", Any) is True
assert _value_satisfies_type(42, Any) is True
assert _value_satisfies_type([1, 2], Any) is True
assert _value_satisfies_type(None, Any) is True
# --- Optional / Union ---
def test_optional_with_value(self):
assert _value_satisfies_type("hello", Optional[str]) is True
assert _value_satisfies_type(42, Optional[int]) is True
def test_optional_mismatch(self):
assert _value_satisfies_type(42, Optional[str]) is False
def test_union_matches_first_member(self):
assert _value_satisfies_type("hello", str | list[str]) is True
def test_union_matches_second_member(self):
assert _value_satisfies_type(["a", "b"], str | list[str]) is True
def test_union_no_match(self):
assert _value_satisfies_type(42, str | list[str]) is False
# --- list[T] ---
def test_list_str_all_match(self):
assert _value_satisfies_type(["a", "b", "c"], list[str]) is True
def test_list_str_inner_mismatch(self):
assert _value_satisfies_type([1, 2, 3], list[str]) is False
def test_list_int_all_match(self):
assert _value_satisfies_type([1, 2, 3], list[int]) is True
def test_list_int_inner_mismatch(self):
assert _value_satisfies_type(["1", "2"], list[int]) is False
def test_empty_list_satisfies_any_list_type(self):
assert _value_satisfies_type([], list[str]) is True
assert _value_satisfies_type([], list[int]) is True
def test_string_does_not_satisfy_list(self):
assert _value_satisfies_type("hello", list[str]) is False
# --- nested list[list[str]] ---
def test_nested_list_all_match(self):
assert _value_satisfies_type([["a", "b"], ["c"]], list[list[str]]) is True
def test_nested_list_inner_mismatch(self):
assert _value_satisfies_type([["a", 1], ["c"]], list[list[str]]) is False
def test_nested_list_outer_mismatch(self):
assert _value_satisfies_type(["a", "b"], list[list[str]]) is False
# --- dict[K, V] ---
def test_dict_str_int_match(self):
assert _value_satisfies_type({"a": 1, "b": 2}, dict[str, int]) is True
def test_dict_str_int_value_mismatch(self):
assert _value_satisfies_type({"a": "1", "b": "2"}, dict[str, int]) is False
def test_dict_str_int_key_mismatch(self):
assert _value_satisfies_type({1: 1, 2: 2}, dict[str, int]) is False
def test_empty_dict_satisfies(self):
assert _value_satisfies_type({}, dict[str, int]) is True
# --- set[T] / tuple[T] ---
def test_set_match(self):
assert _value_satisfies_type({1, 2, 3}, set[int]) is True
def test_set_mismatch(self):
assert _value_satisfies_type({"a", "b"}, set[int]) is False
def test_tuple_homogeneous_match(self):
assert _value_satisfies_type((1, 2, 3), tuple[int, ...]) is True
def test_tuple_homogeneous_mismatch(self):
assert _value_satisfies_type((1, "2", 3), tuple[int, ...]) is False
def test_tuple_heterogeneous_match(self):
assert _value_satisfies_type(("a", 1, True), tuple[str, int, bool]) is True
def test_tuple_heterogeneous_mismatch(self):
assert _value_satisfies_type(("a", "b", True), tuple[str, int, bool]) is False
def test_tuple_heterogeneous_wrong_length(self):
assert _value_satisfies_type(("a", 1), tuple[str, int, bool]) is False
# --- bare generics (no args) ---
def test_bare_list(self):
assert _value_satisfies_type([1, "a"], list) is True
def test_bare_dict(self):
assert _value_satisfies_type({"a": 1}, dict) is True
# --- union with generic inner mismatch ---
def test_union_list_with_wrong_inner_falls_through(self):
# [1, 2] doesn't satisfy list[str] (inner mismatch), and not str either
assert _value_satisfies_type([1, 2], str | list[str]) is False
# --- Literal (non-runtime origin) ---
def test_literal_does_not_crash(self):
"""Literal origins are not runtime types — should return False, not crash."""
assert _value_satisfies_type("active", Literal["active", "inactive"]) is False
# ---------------------------------------------------------------------------
# coerce_inputs_to_schema — using real Pydantic models
# ---------------------------------------------------------------------------
class SampleSchema(BaseModel):
name: str
count: int
items: list[str]
config: dict[str, int] = {}
class NestedSchema(BaseModel):
rows: list[list[str]]
class UnionSchema(BaseModel):
content: str | list[str]
class OptionalSchema(BaseModel):
label: Optional[str] = None
value: int = 0
class AnyFieldSchema(BaseModel):
data: Any
class TestCoerceInputsToSchema:
def test_string_to_int(self):
data: dict[str, Any] = {"name": "test", "count": "42", "items": ["a"]}
coerce_inputs_to_schema(data, SampleSchema)
assert data["count"] == 42
assert isinstance(data["count"], int)
def test_json_string_to_list(self):
data: dict[str, Any] = {"name": "test", "count": 1, "items": '["a","b","c"]'}
coerce_inputs_to_schema(data, SampleSchema)
assert data["items"] == ["a", "b", "c"]
def test_already_correct_types_unchanged(self):
data: dict[str, Any] = {
"name": "test",
"count": 42,
"items": ["a", "b"],
"config": {"x": 1},
}
coerce_inputs_to_schema(data, SampleSchema)
assert data == {
"name": "test",
"count": 42,
"items": ["a", "b"],
"config": {"x": 1},
}
def test_inner_element_coercion(self):
"""list[str] with int inner elements → coerced to strings."""
data: dict[str, Any] = {"name": "test", "count": 1, "items": [1, 2, 3]}
coerce_inputs_to_schema(data, SampleSchema)
assert data["items"] == ["1", "2", "3"]
def test_dict_value_coercion(self):
"""dict[str, int] with string values → coerced to ints."""
data: dict[str, Any] = {
"name": "test",
"count": 1,
"items": [],
"config": {"x": "10", "y": "20"},
}
coerce_inputs_to_schema(data, SampleSchema)
assert data["config"] == {"x": 10, "y": 20}
def test_nested_list_from_json_string(self):
data: dict[str, Any] = {
"rows": '[["Name","Score"],["Alice","90"]]',
}
coerce_inputs_to_schema(data, NestedSchema)
assert data["rows"] == [["Name", "Score"], ["Alice", "90"]]
def test_nested_list_already_correct(self):
original = [["a", "b"], ["c", "d"]]
data: dict[str, Any] = {"rows": original}
coerce_inputs_to_schema(data, NestedSchema)
assert data["rows"] == original
def test_union_preserves_valid_list(self):
"""list[str] value should NOT be stringified for str | list[str]."""
data: dict[str, Any] = {"content": ["a", "b"]}
coerce_inputs_to_schema(data, UnionSchema)
assert data["content"] == ["a", "b"]
assert isinstance(data["content"], list)
def test_union_preserves_valid_string(self):
data: dict[str, Any] = {"content": "hello"}
coerce_inputs_to_schema(data, UnionSchema)
assert data["content"] == "hello"
def test_union_list_with_wrong_inner_gets_coerced(self):
"""[1, 2] for str | list[str] — inner ints don't match list[str],
so convert() is called. convert tries str first → stringifies."""
data: dict[str, Any] = {"content": [1, 2]}
coerce_inputs_to_schema(data, UnionSchema)
# convert([1,2], str | list[str]) tries str first → "[1, 2]"
# This is convert()'s union behavior — str wins over list[str]
assert isinstance(data["content"], (str, list))
def test_skips_none_values(self):
data: dict[str, Any] = {"label": None, "value": "5"}
coerce_inputs_to_schema(data, OptionalSchema)
assert data["label"] is None
assert data["value"] == 5
def test_skips_missing_fields(self):
data: dict[str, Any] = {"value": "10"}
coerce_inputs_to_schema(data, OptionalSchema)
assert "label" not in data
assert data["value"] == 10
def test_any_field_skipped(self):
"""Fields typed as Any should pass through without coercion."""
data: dict[str, Any] = {"data": [1, "mixed", {"nested": True}]}
coerce_inputs_to_schema(data, AnyFieldSchema)
assert data["data"] == [1, "mixed", {"nested": True}]
def test_preserves_all_convert_capabilities(self):
"""Verify coerce_inputs_to_schema doesn't lose any convert() capability
that existed before the _value_satisfies_type gate was added."""
class FullSchema(BaseModel):
as_int: int
as_float: float
as_bool: bool
as_str: str
as_list: list[int]
as_dict: dict[str, str]
data: dict[str, Any] = {
"as_int": "42",
"as_float": "3.14",
"as_bool": "True",
"as_str": 123,
"as_list": "[1,2,3]",
"as_dict": '{"a": "b"}',
}
coerce_inputs_to_schema(data, FullSchema)
assert data["as_int"] == 42
assert data["as_float"] == 3.14
assert data["as_bool"] is True
assert data["as_str"] == "123"
assert data["as_list"] == [1, 2, 3]
assert data["as_dict"] == {"a": "b"}
def test_inherited_fields_are_coerced(self):
"""model_fields includes inherited fields; __annotations__ does not.
This verifies that fields from a parent schema are still coerced."""
class ParentSchema(BaseModel):
base_count: int
class ChildSchema(ParentSchema):
name: str
# base_count is inherited — __annotations__ wouldn't include it
assert "base_count" not in ChildSchema.__annotations__
assert "base_count" in ChildSchema.model_fields
data: dict[str, Any] = {"base_count": "42", "name": "test"}
coerce_inputs_to_schema(data, ChildSchema)
assert data["base_count"] == 42
assert isinstance(data["base_count"], int)
def test_nested_pydantic_model_field(self):
"""dict input for a Pydantic model-typed field passes through.
convert() doesn't construct Pydantic models — Pydantic validation
handles that downstream. This test documents the behavior."""
class InnerModel(BaseModel):
x: int
class OuterModel(BaseModel):
inner: InnerModel
data: dict[str, Any] = {"inner": {"x": 1}}
coerce_inputs_to_schema(data, OuterModel)
# dict stays as dict — convert() doesn't construct Pydantic models
assert data["inner"] == {"x": 1}
assert isinstance(data["inner"], dict)
def test_literal_field_passes_through(self):
"""Literal-typed fields should not crash coercion."""
class LiteralSchema(BaseModel):
status: Literal["active", "inactive"]
data: dict[str, Any] = {"status": "active"}
coerce_inputs_to_schema(data, LiteralSchema)
assert data["status"] == "active"
def test_list_of_pydantic_model_field(self):
"""list[dict] for list[PydanticModel] passes through unchanged."""
class ItemModel(BaseModel):
name: str
class ContainerModel(BaseModel):
items: list[ItemModel]
data: dict[str, Any] = {"items": [{"name": "a"}, {"name": "b"}]}
coerce_inputs_to_schema(data, ContainerModel)
# Dicts stay as dicts — Pydantic validation handles construction
assert data["items"] == [{"name": "a"}, {"name": "b"}]
assert isinstance(data["items"][0], dict)

View File

@@ -1,22 +0,0 @@
-- Migrate Gemini 3 Pro Preview to Gemini 3.1 Pro Preview
-- This updates all AgentNode blocks that use the deprecated Gemini 3 Pro Preview model
-- Google is shutting down google/gemini-3-pro-preview on March 9, 2026
-- Update AgentNode constant inputs
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"google/gemini-3.1-pro-preview"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'google/gemini-3-pro-preview';
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
UPDATE "AgentNodeExecutionInputOutput"
SET "data" = JSONB_SET(
"data"::jsonb,
'{model}',
'"google/gemini-3.1-pro-preview"'::jsonb
)
WHERE "agentPresetId" IS NOT NULL
AND "data"::jsonb->>'model' = 'google/gemini-3-pro-preview';

View File

@@ -1,46 +0,0 @@
/*
Warnings:
- You are about to drop the column `search` on the `StoreListingVersion` table. All the data in the column will be lost.
*/-- CreateEnum
CREATE TYPE "InvitedUserStatus" AS ENUM('INVITED',
'CLAIMED',
'REVOKED');
-- CreateEnum
CREATE TYPE "TallyComputationStatus" AS ENUM('PENDING',
'RUNNING',
'READY',
'FAILED');
-- CreateTable
CREATE TABLE "InvitedUser"(
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"email" TEXT NOT NULL,
"status" "InvitedUserStatus" NOT NULL DEFAULT 'INVITED',
"authUserId" TEXT,
"name" TEXT,
"tallyUnderstanding" JSONB,
"tallyStatus" "TallyComputationStatus" NOT NULL DEFAULT 'PENDING',
"tallyComputedAt" TIMESTAMP(3),
"tallyError" TEXT,
CONSTRAINT "InvitedUser_pkey" PRIMARY KEY("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "InvitedUser_email_key"
ON "InvitedUser"("email");
-- CreateIndex
CREATE UNIQUE INDEX "InvitedUser_authUserId_key"
ON "InvitedUser"("authUserId");
-- CreateIndex
CREATE INDEX "InvitedUser_status_idx"
ON "InvitedUser"("status");
-- CreateIndex
CREATE INDEX "InvitedUser_tallyStatus_idx"
ON "InvitedUser"("tallyStatus");
-- AddForeignKey
ALTER TABLE "InvitedUser" ADD CONSTRAINT "InvitedUser_authUserId_fkey" FOREIGN KEY("authUserId") REFERENCES "User"("id")
ON DELETE
SET NULL
ON UPDATE CASCADE;

View File

@@ -1,15 +0,0 @@
-- Drop the trigger that auto-creates User + Profile on auth.users INSERT.
-- The invite activation flow in get_or_activate_user() now handles this.
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = 'auth' AND table_name = 'users'
) THEN
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
END IF;
END $$;
DROP FUNCTION IF EXISTS add_user_and_profile_to_platform();
DROP FUNCTION IF EXISTS add_user_to_platform();
-- Keep generate_username() — used by backfill migration 20250205110132

View File

@@ -1,7 +0,0 @@
-- DropIndex
DROP INDEX "InvitedUser_status_idx";
-- DropIndex
DROP INDEX "InvitedUser_tallyStatus_idx";
-- CreateIndex
CREATE INDEX "InvitedUser_createdAt_idx"
ON "InvitedUser"("createdAt");

View File

@@ -1,40 +0,0 @@
-- Fix PerplexityBlock nodes that have invalid model values (e.g. gpt-4o,
-- gpt-5.2-2025-12-11) set by the agent generator. Defaults them to the
-- standard "perplexity/sonar" model.
--
-- PerplexityBlock ID: c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f
-- Valid models: perplexity/sonar, perplexity/sonar-pro, perplexity/sonar-deep-research
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"perplexity/sonar"'::jsonb
)
WHERE "agentBlockId" = 'c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f'
AND "constantInput"::jsonb ? 'model'
AND "constantInput"::jsonb->>'model' NOT IN (
'perplexity/sonar',
'perplexity/sonar-pro',
'perplexity/sonar-deep-research'
);
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput).
-- The table links to AgentNode through AgentNodeExecution, not directly.
UPDATE "AgentNodeExecutionInputOutput" io
SET "data" = JSONB_SET(
io."data"::jsonb,
'{model}',
'"perplexity/sonar"'::jsonb
)
FROM "AgentNodeExecution" exe
JOIN "AgentNode" n ON n."id" = exe."agentNodeId"
WHERE io."agentPresetId" IS NOT NULL
AND (io."referencedByInputExecId" = exe."id" OR io."referencedByOutputExecId" = exe."id")
AND n."agentBlockId" = 'c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f'
AND io."data"::jsonb ? 'model'
AND io."data"::jsonb->>'model' NOT IN (
'perplexity/sonar',
'perplexity/sonar-pro',
'perplexity/sonar-deep-research'
);

View File

@@ -1282,14 +1282,14 @@ pgp = ["gpg"]
[[package]]
name = "e2b"
version = "2.15.2"
version = "2.15.1"
description = "E2B SDK that give agents cloud environments"
optional = false
python-versions = "<4.0,>=3.10"
groups = ["main"]
files = [
{file = "e2b-2.15.2-py3-none-any.whl", hash = "sha256:19a56fbdea25974dc81426ed48337eae6cea91d404f5bcf8861a5a2c6e4d982a"},
{file = "e2b-2.15.2.tar.gz", hash = "sha256:414379d2421d6827eeb2eb50a4d6b3fdb7d691b39ff73b5ea05ca4b532819831"},
{file = "e2b-2.15.1-py3-none-any.whl", hash = "sha256:a3bc4e004eab51fb05bae44e9ee4fe821e4637260f4ce3064c8f7c6ed7f5a2a0"},
{file = "e2b-2.15.1.tar.gz", hash = "sha256:a4f1bbc8b5180a8a1098079257fcb73e42503ed546098f676f722f11f0d68c09"},
]
[package.dependencies]
@@ -8882,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
content-hash = "618d61b0586ab82fec1e28d1feb549a198e0b5c9d152e808862e55efc00a65b9"

View File

@@ -20,7 +20,7 @@ claude-agent-sdk = "0.1.45" # see copilot/sdk/sdk_compat_test.py for capability
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"
e2b = "^2.15.2"
e2b = "^2.0"
e2b-code-interpreter = "^2.0"
elevenlabs = "^1.50.0"
fastapi = "^0.128.6"

View File

@@ -65,7 +65,6 @@ model User {
NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -74,38 +73,6 @@ model User {
OAuthRefreshTokens OAuthRefreshToken[]
}
enum InvitedUserStatus {
INVITED
CLAIMED
REVOKED
}
enum TallyComputationStatus {
PENDING
RUNNING
READY
FAILED
}
model InvitedUser {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
email String @unique
status InvitedUserStatus @default(INVITED)
authUserId String? @unique
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
name String?
tallyUnderstanding Json?
tallyStatus TallyComputationStatus @default(PENDING)
tallyComputedAt DateTime?
tallyError String?
@@index([createdAt])
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME
@@ -1025,7 +992,7 @@ model StoreListing {
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
// The agent link here is only so we can do lookup on agentId
agentGraphId String @unique
agentGraphId String @unique
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])

View File

@@ -4,6 +4,7 @@
"id": "test-agent-1",
"graph_id": "test-agent-1",
"graph_version": 1,
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"image_url": null,
"creator_name": "Test Creator",
"creator_image_url": "",
@@ -50,6 +51,7 @@
"id": "test-agent-2",
"graph_id": "test-agent-2",
"graph_version": 1,
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"image_url": null,
"creator_name": "Test Creator",
"creator_image_url": "",

View File

@@ -84,27 +84,6 @@ class TestGmailReadBlock:
assert "Hello World" in result
assert "This is HTML content" in result
@pytest.mark.asyncio
async def test_html_fallback_when_html2text_conversion_fails(self):
"""Fallback to raw HTML when html2text converter raises unexpectedly."""
html_text = "<html><body><p>Broken <b>HTML</p></body></html>"
msg = {
"id": "test_msg_html_error",
"payload": {
"mimeType": "text/html",
"body": {"data": self._encode_base64(html_text)},
},
}
with patch("html2text.HTML2Text") as mock_html2text:
mock_converter = Mock()
mock_converter.handle.side_effect = ValueError("conversion failed")
mock_html2text.return_value = mock_converter
result = await self.gmail_block._get_email_body(msg, self.mock_service)
assert result == html_text
@pytest.mark.asyncio
async def test_html_fallback_when_html2text_unavailable(self):
"""Test fallback to raw HTML when html2text is not available."""

View File

@@ -34,7 +34,7 @@ from backend.data.auth.api_key import create_api_key
from backend.data.credit import get_user_credit_model
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.invited_user import get_or_activate_user
from backend.data.user import get_or_create_user
from backend.util.clients import get_supabase
faker = Faker()
@@ -151,7 +151,7 @@ class TestDataCreator:
}
# Use the API function to create user in local database
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
users.append(user.model_dump())
except Exception as e:

View File

@@ -1,14 +1,7 @@
"use client";
import { Sidebar } from "@/components/__legacy__/Sidebar";
import {
UsersIcon,
CurrencyDollarSimpleIcon,
UserPlusIcon,
MagnifyingGlassIcon,
FileTextIcon,
SlidersHorizontalIcon,
} from "@phosphor-icons/react";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
const sidebarLinkGroups = [
{
@@ -16,32 +9,27 @@ const sidebarLinkGroups = [
{
text: "Marketplace Management",
href: "/admin/marketplace",
icon: <UsersIcon size={24} />,
icon: <Users className="h-6 w-6" />,
},
{
text: "User Spending",
href: "/admin/spending",
icon: <CurrencyDollarSimpleIcon size={24} />,
},
{
text: "Beta Invites",
href: "/admin/users",
icon: <UserPlusIcon size={24} />,
icon: <DollarSign className="h-6 w-6" />,
},
{
text: "User Impersonation",
href: "/admin/impersonation",
icon: <MagnifyingGlassIcon size={24} />,
icon: <UserSearch className="h-6 w-6" />,
},
{
text: "Execution Analytics",
href: "/admin/execution-analytics",
icon: <FileTextIcon size={24} />,
icon: <FileText className="h-6 w-6" />,
},
{
text: "Admin User Management",
href: "/admin/settings",
icon: <SlidersHorizontalIcon size={24} />,
icon: <IconSliders className="h-6 w-6" />,
},
],
},

View File

@@ -1,80 +0,0 @@
"use client";
import { Card } from "@/components/atoms/Card/Card";
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
import { useAdminUsersPage } from "../../useAdminUsersPage";
export function AdminUsersPage() {
const {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers,
isLoadingInvitedUsers,
isRefreshingInvitedUsers,
isCreatingInvite,
isBulkInviting,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
} = useAdminUsersPage();
return (
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
<div className="flex flex-col gap-2">
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
<p className="max-w-3xl text-sm text-zinc-600">
Pre-provision beta users before they sign up. Invites store the
platform-side record, run Tally understanding extraction, and activate
the real account on the user&apos;s first authenticated request.
</p>
</div>
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
<div className="flex flex-col gap-6">
<Card className="border border-zinc-200 shadow-sm">
<InviteUserForm
email={email}
name={name}
isSubmitting={isCreatingInvite}
onEmailChange={setEmail}
onNameChange={setName}
onSubmit={handleCreateInvite}
/>
</Card>
<Card className="border border-zinc-200 shadow-sm">
<BulkInviteForm
selectedFile={bulkInviteFile}
inputKey={bulkInviteInputKey}
isSubmitting={isBulkInviting}
lastResult={lastBulkInviteResult}
onFileChange={handleBulkInviteFileChange}
onSubmit={handleBulkInviteSubmit}
/>
</Card>
</div>
<Card className="border border-zinc-200 shadow-sm">
<InvitedUsersTable
invitedUsers={invitedUsers}
isLoading={isLoadingInvitedUsers}
isRefreshing={isRefreshingInvitedUsers}
pendingInviteAction={pendingInviteAction}
onRetryTally={handleRetryTally}
onRevoke={handleRevoke}
/>
</Card>
</div>
</div>
);
}

View File

@@ -1,135 +0,0 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import type { FormEvent } from "react";
interface Props {
selectedFile: File | null;
inputKey: number;
isSubmitting: boolean;
lastResult: BulkInvitedUsersResponse | null;
onFileChange: (file: File | null) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
if (status === "CREATED") {
return "success";
}
if (status === "ERROR") {
return "error";
}
return "info";
}
export function BulkInviteForm({
selectedFile,
inputKey,
isSubmitting,
lastResult,
onFileChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
<p className="text-sm text-zinc-600">
Upload a <span className="font-medium text-zinc-800">.txt</span> file
with one email per line, or a{" "}
<span className="font-medium text-zinc-800">.csv</span> with
<span className="font-medium text-zinc-800"> email</span> and optional
<span className="font-medium text-zinc-800"> name</span> columns.
</p>
</div>
<label
htmlFor="bulk-invite-file-input"
className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors focus-within:ring-2 focus-within:ring-zinc-500 focus-within:ring-offset-2 hover:border-zinc-400 hover:bg-zinc-100"
>
<span className="font-medium text-zinc-900">
{selectedFile ? selectedFile.name : "Choose invite file"}
</span>
<span>Maximum 500 rows, UTF-8 encoded.</span>
<input
id="bulk-invite-file-input"
key={inputKey}
type="file"
accept=".txt,.csv,text/plain,text/csv"
disabled={isSubmitting}
className="sr-only"
onChange={(event) =>
onFileChange(event.target.files?.item(0) ?? null)
}
/>
</label>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!selectedFile}
className="w-full"
>
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
</Button>
{lastResult ? (
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
<div className="grid grid-cols-3 gap-2 text-center">
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.created_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Created
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.skipped_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Skipped
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.error_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Errors
</div>
</div>
</div>
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
<div className="flex flex-col divide-y divide-zinc-100">
{lastResult.results.map((row) => (
<div
key={`${row.row_number}-${row.email ?? row.message}`}
className="flex items-start gap-3 px-3 py-3"
>
<Badge variant={getStatusVariant(row.status)} size="small">
{row.status}
</Badge>
<div className="flex min-w-0 flex-1 flex-col gap-1">
<span className="text-sm font-medium text-zinc-900">
Row {row.row_number}
{row.email ? ` · ${row.email}` : ""}
</span>
<span className="text-xs text-zinc-500">{row.message}</span>
</div>
</div>
))}
</div>
</div>
</div>
) : null}
</form>
);
}

View File

@@ -1,66 +0,0 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import type { FormEvent } from "react";
interface Props {
email: string;
name: string;
isSubmitting: boolean;
onEmailChange: (value: string) => void;
onNameChange: (value: string) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
export function InviteUserForm({
email,
name,
isSubmitting,
onEmailChange,
onNameChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
<p className="text-sm text-zinc-600">
The invite is stored immediately, then Tally pre-seeding starts in the
background.
</p>
</div>
<Input
id="invite-email"
label="Email"
type="email"
value={email}
placeholder="jane@example.com"
autoComplete="email"
disabled={isSubmitting}
onChange={(event) => onEmailChange(event.target.value)}
/>
<Input
id="invite-name"
label="Name"
type="text"
value={name}
placeholder="Jane Doe"
disabled={isSubmitting}
onChange={(event) => onNameChange(event.target.value)}
/>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!email.trim()}
className="w-full"
>
{isSubmitting ? "Creating invite..." : "Create invite"}
</Button>
</form>
);
}

View File

@@ -1,209 +0,0 @@
"use client";
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
interface Props {
invitedUsers: InvitedUserResponse[];
isLoading: boolean;
isRefreshing: boolean;
pendingInviteAction: string | null;
onRetryTally: (invitedUserId: string) => void;
onRevoke: (invitedUserId: string) => void;
}
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
if (status === "CLAIMED") {
return "success";
}
if (status === "REVOKED") {
return "error";
}
return "info";
}
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
if (status === "READY") {
return "success";
}
if (status === "FAILED") {
return "error";
}
return "info";
}
function formatDate(value: Date | undefined) {
if (!value) {
return "-";
}
return value.toLocaleString();
}
function getTallySummary(invitedUser: InvitedUserResponse) {
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
return invitedUser.tally_error;
}
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
return "Stored and ready for activation";
}
if (invitedUser.tally_status === "READY") {
return "No matching Tally submission found";
}
if (invitedUser.tally_status === "RUNNING") {
return "Extraction in progress";
}
return "Waiting to run";
}
function isActionPending(
pendingInviteAction: string | null,
action: "retry" | "revoke",
invitedUserId: string,
) {
return pendingInviteAction === `${action}:${invitedUserId}`;
}
export function InvitedUsersTable({
invitedUsers,
isLoading,
isRefreshing,
pendingInviteAction,
onRetryTally,
onRevoke,
}: Props) {
return (
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
<p className="text-sm text-zinc-600">
Live invite state, claim status, and Tally pre-seeding progress.
</p>
</div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
</span>
</div>
<div className="overflow-hidden rounded-2xl border border-zinc-200">
<Table>
<TableHeader className="bg-zinc-50">
<TableRow>
<TableHead>Email</TableHead>
<TableHead>Name</TableHead>
<TableHead>Invite</TableHead>
<TableHead>Tally</TableHead>
<TableHead>Claimed User</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{isLoading ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
Loading invited users...
</TableCell>
</TableRow>
) : invitedUsers.length === 0 ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
No invited users yet
</TableCell>
</TableRow>
) : (
invitedUsers.map((invitedUser) => (
<TableRow key={invitedUser.id} className="align-top">
<TableCell className="font-medium text-zinc-900">
{invitedUser.email}
</TableCell>
<TableCell>{invitedUser.name || "-"}</TableCell>
<TableCell>
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
{invitedUser.status}
</Badge>
</TableCell>
<TableCell>
<div className="flex max-w-xs flex-col gap-2">
<Badge
variant={getTallyBadgeVariant(invitedUser.tally_status)}
>
{invitedUser.tally_status}
</Badge>
<span className="text-xs text-zinc-500">
{getTallySummary(invitedUser)}
</span>
<span className="text-xs text-zinc-400">
{formatDate(invitedUser.tally_computed_at ?? undefined)}
</span>
</div>
</TableCell>
<TableCell className="font-mono text-xs text-zinc-500">
{invitedUser.auth_user_id || "-"}
</TableCell>
<TableCell className="text-sm text-zinc-500">
{formatDate(invitedUser.created_at)}
</TableCell>
<TableCell>
<div className="flex justify-end gap-2">
<Button
variant="outline"
size="small"
disabled={invitedUser.status === "REVOKED"}
loading={isActionPending(
pendingInviteAction,
"retry",
invitedUser.id,
)}
onClick={() => onRetryTally(invitedUser.id)}
>
Retry Tally
</Button>
<Button
variant="secondary"
size="small"
disabled={invitedUser.status !== "INVITED"}
loading={isActionPending(
pendingInviteAction,
"revoke",
invitedUser.id,
)}
onClick={() => onRevoke(invitedUser.id)}
>
Revoke
</Button>
</div>
</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
</div>
</div>
);
}

View File

@@ -1,11 +1,16 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
import React from "react";
function AdminUsers() {
return <AdminUsersPage />;
return (
<div>
<h1>Users Dashboard</h1>
{/* Add your admin-only content here */}
</div>
);
}
export default async function AdminUsersRoute() {
export default async function AdminUsersPage() {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);

View File

@@ -1,197 +0,0 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { okData } from "@/app/api/helpers";
import {
getGetV2ListInvitedUsersQueryKey,
useGetV2ListInvitedUsers,
usePostV2BulkCreateInvitedUsers,
usePostV2CreateInvitedUser,
usePostV2RetryInvitedUserTally,
usePostV2RevokeInvitedUser,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useQueryClient } from "@tanstack/react-query";
import { type FormEvent, useState } from "react";
function getErrorMessage(error: unknown) {
if (error instanceof Error) {
return error.message;
}
return "Something went wrong";
}
export function useAdminUsersPage() {
const queryClient = useQueryClient();
const { toast } = useToast();
const [email, setEmail] = useState("");
const [name, setName] = useState("");
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
const [lastBulkInviteResult, setLastBulkInviteResult] =
useState<BulkInvitedUsersResponse | null>(null);
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
null,
);
const invitedUsersQuery = useGetV2ListInvitedUsers(undefined, {
query: {
select: okData,
refetchInterval: 30_000,
},
});
const createInvitedUserMutation = usePostV2CreateInvitedUser({
mutation: {
onSuccess: async () => {
setEmail("");
setName("");
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invited user created",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
mutation: {
onSuccess: async (response) => {
const result = okData(response) ?? null;
setBulkInviteFile(null);
setBulkInviteInputKey((currentValue) => currentValue + 1);
setLastBulkInviteResult(result);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: result
? `${result.created_count} invites created`
: "Bulk invite upload complete",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Tally pre-seeding restarted",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invite revoked",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
createInvitedUserMutation.mutate({
data: {
email,
name: name.trim() || null,
},
});
}
function handleRetryTally(invitedUserId: string) {
setPendingInviteAction(`retry:${invitedUserId}`);
retryInvitedUserTallyMutation.mutate({ invitedUserId });
}
function handleBulkInviteFileChange(file: File | null) {
setBulkInviteFile(file);
}
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
if (!bulkInviteFile) {
return;
}
bulkCreateInvitedUsersMutation.mutate({
data: {
file: bulkInviteFile,
},
});
}
function handleRevoke(invitedUserId: string) {
setPendingInviteAction(`revoke:${invitedUserId}`);
revokeInvitedUserMutation.mutate({ invitedUserId });
}
return {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
invitedUsersError: invitedUsersQuery.error,
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
isCreatingInvite: createInvitedUserMutation.isPending,
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
};
}

View File

@@ -75,7 +75,7 @@ export const getSecondCalculatorNode = () => {
export const getFormContainerSelector = (blockId: string): string | null => {
const node = getNodeByBlockId(blockId);
if (node) {
return `[data-id="form-creator-container-${node.id}-node"]`;
return `[data-id="form-creator-container-${node.id}"]`;
}
return null;
};

View File

@@ -7,7 +7,6 @@
*
* Typography (body, small, action, info, tip, warning) uses Tailwind utilities directly in steps.ts
*/
import "shepherd.js/dist/css/shepherd.css";
import "./tutorial.css";
export const injectTutorialStyles = () => {

View File

@@ -1,14 +1,3 @@
.new-builder-tutorial-disable {
opacity: 0.3 !important;
pointer-events: none !important;
filter: grayscale(100%) !important;
}
.new-builder-tutorial-highlight {
position: relative;
z-index: 10;
}
.new-builder-tutorial-highlight * {
opacity: 1 !important;
filter: none !important;

View File

@@ -1,16 +1,20 @@
"use client";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { UploadSimple } from "@phosphor-icons/react";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { NotificationBanner } from "./components/NotificationBanner/NotificationBanner";
import { NotificationDialog } from "./components/NotificationDialog/NotificationDialog";
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
import { useCopilotPage } from "./useCopilotPage";
@@ -86,6 +90,7 @@ export function CopilotPage() {
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
@@ -112,7 +117,6 @@ export function CopilotPage() {
onDrop={handleDrop}
>
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<NotificationBanner />
{/* Drop overlay */}
<div
className={cn(
@@ -141,6 +145,38 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className="rounded p-1.5 hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-5 w-5 text-neutral-600" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => {
const session = sessions.find(
(s) => s.id === sessionId,
);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
) : undefined
}
/>
</div>
</div>
@@ -165,7 +201,6 @@ export function CopilotPage() {
onCancel={handleCancelDelete}
/>
)}
<NotificationDialog />
</SidebarProvider>
);
}

View File

@@ -2,6 +2,7 @@
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ReactNode } from "react";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
@@ -20,6 +21,7 @@ export interface ChatContainerProps {
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -38,6 +40,7 @@ export const ChatContainer = ({
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
@@ -60,6 +63,7 @@ export const ChatContainer = ({
status={status}
error={error}
isLoading={isLoadingSession}
headerSlot={headerSlot}
sessionID={sessionId}
/>
<motion.div

View File

@@ -30,6 +30,7 @@ interface Props {
status: string;
error: Error | undefined;
isLoading: boolean;
headerSlot?: React.ReactNode;
sessionID?: string | null;
}
@@ -101,6 +102,7 @@ export function ChatMessagesContainer({
status,
error,
isLoading,
headerSlot,
sessionID,
}: Props) {
const lastMessage = messages[messages.length - 1];
@@ -133,6 +135,7 @@ export function ChatMessagesContainer({
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
{headerSlot}
{isLoading && messages.length === 0 && (
<div
className="flex flex-1 items-center justify-center"

View File

@@ -23,37 +23,24 @@ import {
useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import {
CheckCircle,
DotsThree,
PlusCircleIcon,
PlusIcon,
} from "@phosphor-icons/react";
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query";
import { AnimatePresence, motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "../../store";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
import { UsageLimits } from "../UsageLimits/UsageLimits";
export function ChatSidebar() {
const { state } = useSidebar();
const isCollapsed = state === "collapsed";
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
const {
sessionToDelete,
setSessionToDelete,
completedSessionIDs,
clearCompletedSession,
} = useCopilotUIStore();
const { sessionToDelete, setSessionToDelete } = useCopilotUIStore();
const queryClient = useQueryClient();
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions({ limit: 50 }, { query: { refetchInterval: 10_000 } });
useGetV2ListSessions({ limit: 50 });
const { mutate: deleteSession, isPending: isDeleting } =
useDeleteV2DeleteSession({
@@ -112,22 +99,6 @@ export function ChatSidebar() {
}
}, [editingSessionId]);
// Refetch session list when active session changes
useEffect(() => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}, [sessionId, queryClient]);
// Clear completed indicator when navigating to a session (works for all paths)
useEffect(() => {
if (!sessionId || !completedSessionIDs.has(sessionId)) return;
clearCompletedSession(sessionId);
const remaining = completedSessionIDs.size - 1;
document.title =
remaining > 0 ? `(${remaining}) Otto is ready - AutoGPT` : "AutoGPT";
}, [sessionId, completedSessionIDs, clearCompletedSession]);
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
@@ -257,9 +228,7 @@ export function ChatSidebar() {
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="flex items-center">
<UsageLimits />
<NotificationToggle />
<div className="relative left-6">
<SidebarTrigger />
</div>
</div>
@@ -336,8 +305,8 @@ export function ChatSidebar() {
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full items-center gap-2">
<div className="min-w-0 flex-1">
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
@@ -360,22 +329,10 @@ export function ChatSidebar() {
</motion.span>
</AnimatePresence>
</Text>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
{session.is_processing &&
session.id !== sessionId &&
!completedSessionIDs.has(session.id) && (
<PulseLoader size={16} className="shrink-0" />
)}
{completedSessionIDs.has(session.id) &&
session.id !== sessionId && (
<CheckCircle
className="h-4 w-4 shrink-0 text-green-500"
weight="fill"
/>
)}
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
)}

View File

@@ -1,90 +0,0 @@
"use client";
import { Switch } from "@/components/atoms/Switch/Switch";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
import { useCopilotUIStore } from "../../../../store";
export function NotificationToggle() {
const {
isNotificationsEnabled,
setNotificationsEnabled,
isSoundEnabled,
toggleSound,
} = useCopilotUIStore();
async function handleToggleNotifications() {
if (isNotificationsEnabled) {
setNotificationsEnabled(false);
return;
}
if (typeof Notification === "undefined") {
toast({
title: "Notifications not supported",
description: "Your browser does not support notifications.",
variant: "destructive",
});
return;
}
const permission = await Notification.requestPermission();
if (permission === "granted") {
setNotificationsEnabled(true);
} else {
toast({
title: "Notifications blocked",
description:
"Please allow notifications in your browser settings to enable this feature.",
variant: "destructive",
});
}
}
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Notification settings">
{!isNotificationsEnabled ? (
<BellSlash className="!size-5" />
) : isSoundEnabled ? (
<BellRinging className="!size-5" />
) : (
<Bell className="!size-5" />
)}
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-56 p-3">
<div className="flex flex-col gap-3">
<label className="flex items-center justify-between">
<span className="text-sm text-zinc-700">Notifications</span>
<Switch
checked={isNotificationsEnabled}
onCheckedChange={handleToggleNotifications}
/>
</label>
<label className="flex items-center justify-between">
<span
className={cn(
"text-sm text-zinc-700",
!isNotificationsEnabled && "opacity-50",
)}
>
Sound
</span>
<Switch
checked={isSoundEnabled && isNotificationsEnabled}
onCheckedChange={toggleSound}
disabled={!isNotificationsEnabled}
/>
</label>
</div>
</PopoverContent>
</Popover>
);
}

View File

@@ -1,9 +1,7 @@
"use client";
import { useGetV2GetSuggestedPrompts } from "@/app/api/__generated__/endpoints/chat/chat";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { Button } from "@/components/atoms/Button/Button";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { SpinnerGapIcon } from "@phosphor-icons/react";
@@ -35,38 +33,15 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
useGetV2GetSuggestedPrompts({
query: { staleTime: Infinity },
});
const customPrompts =
suggestedPromptsResponse?.status === 200
? suggestedPromptsResponse.data.prompts
: undefined;
const quickActions = getQuickActions(customPrompts);
const quickActions = getQuickActions();
const [loadingAction, setLoadingAction] = useState<string | null>(null);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
// Use matchMedia instead of resize event — fires only when crossing
// the 500px and 1081px breakpoints defined in getInputPlaceholder(),
// rather than dozens of times per second during a window drag.
useEffect(() => {
function update() {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
}
const mq500 = window.matchMedia("(min-width: 500px)");
const mq1081 = window.matchMedia("(min-width: 1081px)");
update();
mq500.addEventListener("change", update);
mq1081.addEventListener("change", update);
return () => {
mq500.removeEventListener("change", update);
mq1081.removeEventListener("change", update);
};
}, []);
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
}, [window.innerWidth]);
async function handleQuickActionClick(action: string) {
if (isCreatingSession || loadingAction) return;
@@ -116,32 +91,28 @@ export function EmptySession({
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{isLoadingPrompts
? Array.from({ length: 3 }, (_, i) => (
<Skeleton key={i} className="h-10 w-64 shrink-0 rounded-full" />
))
: quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => void handleQuickActionClick(action)}
disabled={isCreatingSession || loadingAction !== null}
aria-busy={loadingAction === action}
leftIcon={
loadingAction === action ? (
<SpinnerGapIcon
className="h-4 w-4 animate-spin"
weight="bold"
/>
) : null
}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
{quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => void handleQuickActionClick(action)}
disabled={isCreatingSession || loadingAction !== null}
aria-busy={loadingAction === action}
leftIcon={
loadingAction === action ? (
<SpinnerGapIcon
className="h-4 w-4 animate-spin"
weight="bold"
/>
) : null
}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
</div>
</motion.div>
</div>

View File

@@ -12,17 +12,12 @@ export function getInputPlaceholder(width?: number) {
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}
const DEFAULT_QUICK_ACTIONS = [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
export function getQuickActions(customPrompts?: string[]) {
if (customPrompts && customPrompts.length > 0) {
return customPrompts;
}
return DEFAULT_QUICK_ACTIONS;
export function getQuickActions() {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
}
export function getGreetingName(user?: User | null) {

View File

@@ -3,17 +3,8 @@ import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import {
CheckCircle,
PlusIcon,
SpeakerHigh,
SpeakerSlash,
SpinnerGapIcon,
X,
} from "@phosphor-icons/react";
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
import { useCopilotUIStore } from "../../store";
import { PulseLoader } from "../PulseLoader/PulseLoader";
interface Props {
isOpen: boolean;
@@ -61,13 +52,6 @@ export function MobileDrawer({
onClose,
onOpenChange,
}: Props) {
const {
completedSessionIDs,
clearCompletedSession,
isSoundEnabled,
toggleSound,
} = useCopilotUIStore();
return (
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
<Drawer.Portal>
@@ -78,31 +62,14 @@ export function MobileDrawer({
<Drawer.Title className="text-lg font-semibold text-zinc-800">
Your chats
</Drawer.Title>
<div className="flex items-center gap-1">
<button
onClick={toggleSound}
className="rounded p-1.5 text-zinc-400 transition-colors hover:text-zinc-600"
aria-label={
isSoundEnabled
? "Disable notification sound"
: "Enable notification sound"
}
>
{isSoundEnabled ? (
<SpeakerHigh className="h-4 w-4" />
) : (
<SpeakerSlash className="h-4 w-4" />
)}
</button>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1rem" height="1rem" />
</Button>
</div>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1rem" height="1rem" />
</Button>
</div>
{currentSessionId ? (
<div className="mt-2">
@@ -136,12 +103,7 @@ export function MobileDrawer({
sessions.map((session) => (
<button
key={session.id}
onClick={() => {
onSelectSession(session.id);
if (completedSessionIDs.has(session.id)) {
clearCompletedSession(session.id);
}
}}
onClick={() => onSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === currentSessionId
@@ -150,7 +112,7 @@ export function MobileDrawer({
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="flex min-w-0 max-w-full items-center gap-1.5">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
@@ -162,18 +124,6 @@ export function MobileDrawer({
>
{session.title || "Untitled chat"}
</Text>
{session.is_processing &&
!completedSessionIDs.has(session.id) &&
session.id !== currentSessionId && (
<PulseLoader size={8} className="shrink-0" />
)}
{completedSessionIDs.has(session.id) &&
session.id !== currentSessionId && (
<CheckCircle
className="h-4 w-4 shrink-0 text-green-500"
weight="fill"
/>
)}
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}

View File

@@ -1,74 +0,0 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { Key, storage } from "@/services/storage/local-storage";
import { BellRinging, X } from "@phosphor-icons/react";
import { useEffect, useState } from "react";
import { useCopilotUIStore } from "../../store";
export function NotificationBanner() {
const { setNotificationsEnabled, isNotificationsEnabled } =
useCopilotUIStore();
const [dismissed, setDismissed] = useState(
() => storage.get(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED) === "true",
);
const [permission, setPermission] = useState(() =>
typeof Notification !== "undefined" ? Notification.permission : "denied",
);
// Re-read dismissed flag when notifications are toggled off (e.g. clearCopilotLocalData)
useEffect(() => {
if (!isNotificationsEnabled) {
setDismissed(
storage.get(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED) === "true",
);
}
}, [isNotificationsEnabled]);
// Don't show if notifications aren't supported, already decided, dismissed, or already enabled
if (
typeof Notification === "undefined" ||
permission !== "default" ||
dismissed ||
isNotificationsEnabled
) {
return null;
}
function handleEnable() {
Notification.requestPermission().then((result) => {
setPermission(result);
if (result === "granted") {
setNotificationsEnabled(true);
handleDismiss();
}
});
}
function handleDismiss() {
storage.set(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED, "true");
setDismissed(true);
}
return (
<div className="flex items-center gap-3 border-b border-amber-200 bg-amber-50 px-4 py-2.5">
<BellRinging className="h-5 w-5 shrink-0 text-amber-600" weight="fill" />
<Text variant="body" className="flex-1 text-sm text-amber-800">
Enable browser notifications to know when Otto finishes working, even
when you switch tabs.
</Text>
<Button variant="primary" size="small" onClick={handleEnable}>
Enable
</Button>
<button
onClick={handleDismiss}
className="rounded p-1 text-amber-400 transition-colors hover:text-amber-600"
aria-label="Dismiss"
>
<X className="h-4 w-4" />
</button>
</div>
);
}

View File

@@ -1,95 +0,0 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Key, storage } from "@/services/storage/local-storage";
import { BellRinging } from "@phosphor-icons/react";
import { useEffect, useState } from "react";
import { useCopilotUIStore } from "../../store";
export function NotificationDialog() {
const {
showNotificationDialog,
setShowNotificationDialog,
setNotificationsEnabled,
isNotificationsEnabled,
} = useCopilotUIStore();
const [dismissed, setDismissed] = useState(
() => storage.get(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED) === "true",
);
const [permission, setPermission] = useState(() =>
typeof Notification !== "undefined" ? Notification.permission : "denied",
);
// Re-read dismissed flag when notifications are toggled off (e.g. clearCopilotLocalData)
useEffect(() => {
if (!isNotificationsEnabled) {
setDismissed(
storage.get(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED) === "true",
);
}
}, [isNotificationsEnabled]);
const shouldShowAuto =
typeof Notification !== "undefined" &&
permission === "default" &&
!dismissed;
const isOpen = showNotificationDialog || shouldShowAuto;
function handleEnable() {
if (typeof Notification === "undefined") {
handleDismiss();
return;
}
Notification.requestPermission().then((result) => {
setPermission(result);
if (result === "granted") {
setNotificationsEnabled(true);
handleDismiss();
}
});
}
function handleDismiss() {
storage.set(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED, "true");
setDismissed(true);
setShowNotificationDialog(false);
}
return (
<Dialog
title="Stay in the loop"
styling={{ maxWidth: "28rem", minWidth: "auto" }}
controlled={{
isOpen,
set: async (open) => {
if (!open) handleDismiss();
},
}}
onClose={handleDismiss}
>
<Dialog.Content>
<div className="flex flex-col items-center gap-4 py-2">
<div className="flex h-12 w-12 items-center justify-center rounded-full bg-violet-100">
<BellRinging className="h-6 w-6 text-violet-600" weight="fill" />
</div>
<Text variant="body" className="text-center text-neutral-600">
Otto can notify you when a response is ready, even if you switch
tabs or close this page. Enable notifications so you never miss one.
</Text>
</div>
<Dialog.Footer>
<Button variant="secondary" onClick={handleDismiss}>
Not now
</Button>
<Button variant="primary" onClick={handleEnable}>
Enable notifications
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -15,8 +15,6 @@
position: absolute;
left: 0;
top: 0;
transform: scale(0);
opacity: 0;
animation: ripple 2s linear infinite;
}
@@ -27,10 +25,7 @@
@keyframes ripple {
0% {
transform: scale(0);
opacity: 0.6;
}
50% {
opacity: 0.3;
opacity: 1;
}
100% {
transform: scale(1);

View File

@@ -1,146 +0,0 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { Button } from "@/components/ui/button";
import { ChartBar } from "@phosphor-icons/react";
import { useUsageLimits } from "./useUsageLimits";
const MS_PER_MINUTE = 60_000;
const MS_PER_HOUR = 3_600_000;
function formatResetTime(resetsAt: Date | string): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const now = new Date();
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / MS_PER_HOUR);
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">{label}</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
Resets {formatResetTime(resetsAt)}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
export function UsagePanelContent({
usage,
showBillingLink = true,
}: {
usage: CoPilotUsageStatus;
showBillingLink?: boolean;
}) {
const hasDailyLimit = usage.daily.limit > 0;
const hasWeeklyLimit = usage.weekly.limit > 0;
if (!hasDailyLimit && !hasWeeklyLimit) {
return (
<div className="text-xs text-neutral-500">No usage limits configured</div>
);
}
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
{hasDailyLimit && (
<UsageBar
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{hasWeeklyLimit && (
<UsageBar
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
{showBillingLink && (
<a
href="/profile/credits"
className="text-[11px] text-blue-600 hover:underline"
>
Learn more about usage limits
</a>
)}
</div>
);
}
export function UsageLimits() {
const { data: usage, isLoading } = useUsageLimits();
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Usage limits">
<ChartBar className="!size-5" weight="light" />
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-64 p-3">
<UsagePanelContent usage={usage} />
</PopoverContent>
</Popover>
);
}

View File

@@ -1,121 +0,0 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsageLimits } from "../UsageLimits";
// Mock the useUsageLimits hook
const mockUseUsageLimits = vi.fn();
vi.mock("../useUsageLimits", () => ({
useUsageLimits: () => mockUseUsageLimits(),
}));
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
vi.mock("@/components/molecules/Popover/Popover", () => ({
Popover: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
mockUseUsageLimits.mockReset();
});
function makeUsage({
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
};
}
describe("UsageLimits", () => {
it("renders nothing while loading", () => {
mockUseUsageLimits.mockReturnValue({ data: undefined, isLoading: true });
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders nothing when no limits are configured", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
isLoading: false,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders the usage button when limits exist", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
});
it("displays daily and weekly usage percentages", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("50% used")).toBeDefined();
expect(screen.getByText("Today")).toBeDefined();
expect(screen.getByText("This week")).toBeDefined();
expect(screen.getByText("Usage limits")).toBeDefined();
});
it("shows only weekly bar when daily limit is 0", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({
dailyLimit: 0,
weeklyUsed: 25000,
weeklyLimit: 50000,
}),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("This week")).toBeDefined();
expect(screen.queryByText("Today")).toBeNull();
});
it("caps percentage at 100% when over limit", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("100% used")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseUsageLimits.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
const link = screen.getByText("Learn more about usage limits");
expect(link).toBeDefined();
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
});
});

Some files were not shown because too many files have changed in this diff Show More