mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
15 Commits
feat/githu
...
swiftyos/n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b6903b70d | ||
|
|
e0e7b129ed | ||
|
|
2f6f02971e | ||
|
|
dfda58306b | ||
|
|
f13e4f60c9 | ||
|
|
2246799694 | ||
|
|
2456882e47 | ||
|
|
5fe35fd156 | ||
|
|
43d71107ef | ||
|
|
050dcd02b6 | ||
|
|
72856b0c11 | ||
|
|
5f574a5974 | ||
|
|
c773faca96 | ||
|
|
97d83aaa75 | ||
|
|
182927a1d4 |
@@ -6,11 +6,13 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
from backend.data.model import User
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
@@ -90,3 +92,51 @@ class BulkInvitedUsersResponse(BaseModel):
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AdminCopilotUserSummary(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: Optional[str] = None
|
||||
timezone: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_user(cls, user: "User") -> "AdminCopilotUserSummary":
|
||||
return cls(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
timezone=user.timezone,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class AdminCopilotUsersResponse(BaseModel):
|
||||
users: list[AdminCopilotUserSummary]
|
||||
|
||||
|
||||
class TriggerCopilotSessionRequest(BaseModel):
|
||||
user_id: str
|
||||
start_type: ChatSessionStartType
|
||||
|
||||
|
||||
class TriggerCopilotSessionResponse(BaseModel):
|
||||
session_id: str
|
||||
start_type: ChatSessionStartType
|
||||
|
||||
|
||||
class SendCopilotEmailsRequest(BaseModel):
|
||||
user_id: str
|
||||
|
||||
|
||||
class SendCopilotEmailsResponse(BaseModel):
|
||||
candidate_count: int
|
||||
processed_count: int
|
||||
sent_count: int
|
||||
skipped_count: int
|
||||
repair_queued_count: int
|
||||
running_count: int
|
||||
failed_count: int
|
||||
|
||||
@@ -2,8 +2,12 @@ import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
from fastapi import APIRouter, File, HTTPException, Query, Security, UploadFile
|
||||
|
||||
from backend.copilot.autopilot import (
|
||||
send_pending_copilot_emails_for_user,
|
||||
trigger_autopilot_session_for_user,
|
||||
)
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
@@ -12,13 +16,20 @@ from backend.data.invited_user import (
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.data.user import search_users
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
AdminCopilotUsersResponse,
|
||||
AdminCopilotUserSummary,
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
SendCopilotEmailsRequest,
|
||||
SendCopilotEmailsResponse,
|
||||
TriggerCopilotSessionRequest,
|
||||
TriggerCopilotSessionResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -135,3 +146,95 @@ async def retry_invited_user_tally_route(
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/copilot/users",
|
||||
response_model=AdminCopilotUsersResponse,
|
||||
summary="Search Copilot Users",
|
||||
operation_id="getV2SearchCopilotUsers",
|
||||
)
|
||||
async def search_copilot_users_route(
|
||||
search: str = Query("", description="Search by email, name, or user ID"),
|
||||
limit: int = Query(20, ge=1, le=50),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> AdminCopilotUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s searched Copilot users (query_length=%s, limit=%s)",
|
||||
admin_user_id,
|
||||
len(search.strip()),
|
||||
limit,
|
||||
)
|
||||
users = await search_users(search, limit=limit)
|
||||
return AdminCopilotUsersResponse(
|
||||
users=[AdminCopilotUserSummary.from_user(user) for user in users]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/copilot/trigger",
|
||||
response_model=TriggerCopilotSessionResponse,
|
||||
summary="Trigger Copilot Session",
|
||||
operation_id="postV2TriggerCopilotSession",
|
||||
)
|
||||
async def trigger_copilot_session_route(
|
||||
request: TriggerCopilotSessionRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> TriggerCopilotSessionResponse:
|
||||
logger.info(
|
||||
"Admin user %s manually triggered %s for user %s",
|
||||
admin_user_id,
|
||||
request.start_type,
|
||||
request.user_id,
|
||||
)
|
||||
try:
|
||||
session = await trigger_autopilot_session_for_user(
|
||||
request.user_id,
|
||||
start_type=request.start_type,
|
||||
)
|
||||
except LookupError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
logger.info(
|
||||
"Admin user %s created manual Copilot session %s for user %s",
|
||||
admin_user_id,
|
||||
session.session_id,
|
||||
request.user_id,
|
||||
)
|
||||
return TriggerCopilotSessionResponse(
|
||||
session_id=session.session_id,
|
||||
start_type=request.start_type,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/copilot/send-emails",
|
||||
response_model=SendCopilotEmailsResponse,
|
||||
summary="Send Pending Copilot Emails",
|
||||
operation_id="postV2SendPendingCopilotEmails",
|
||||
)
|
||||
async def send_pending_copilot_emails_route(
|
||||
request: SendCopilotEmailsRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> SendCopilotEmailsResponse:
|
||||
logger.info(
|
||||
"Admin user %s manually triggered pending Copilot emails for user %s",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
)
|
||||
result = await send_pending_copilot_emails_for_user(request.user_id)
|
||||
logger.info(
|
||||
"Admin user %s completed pending Copilot email sweep for user %s "
|
||||
"(candidates=%s, sent=%s, skipped=%s, repairs=%s, running=%s, failed=%s)",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
result.candidate_count,
|
||||
result.sent_count,
|
||||
result.skipped_count,
|
||||
result.repair_queued_count,
|
||||
result.running_count,
|
||||
result.failed_count,
|
||||
)
|
||||
return SendCopilotEmailsResponse(**result.model_dump())
|
||||
|
||||
@@ -8,11 +8,15 @@ import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.copilot.autopilot_email import PendingCopilotEmailSweepResult
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
from backend.data.model import User
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
@@ -72,6 +76,20 @@ def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
)
|
||||
|
||||
|
||||
def _sample_user() -> User:
|
||||
now = datetime.now(timezone.utc)
|
||||
return User(
|
||||
id="user-1",
|
||||
email="copilot@example.com",
|
||||
name="Copilot User",
|
||||
timezone="Europe/Madrid",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
stripe_customer_id=None,
|
||||
top_up_config=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
@@ -166,3 +184,107 @@ def test_retry_invited_user_tally(
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
|
||||
|
||||
def test_search_copilot_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.search_users",
|
||||
AsyncMock(return_value=[_sample_user()]),
|
||||
)
|
||||
|
||||
response = client.get("/admin/copilot/users", params={"search": "copilot"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["users"]) == 1
|
||||
assert data["users"][0]["email"] == "copilot@example.com"
|
||||
assert data["users"][0]["timezone"] == "Europe/Madrid"
|
||||
|
||||
|
||||
def test_trigger_copilot_session(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
)
|
||||
trigger = mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.trigger_autopilot_session_for_user",
|
||||
AsyncMock(return_value=session),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/copilot/trigger",
|
||||
json={
|
||||
"user_id": "user-1",
|
||||
"start_type": ChatSessionStartType.AUTOPILOT_CALLBACK.value,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["session_id"] == session.session_id
|
||||
assert response.json()["start_type"] == "AUTOPILOT_CALLBACK"
|
||||
assert trigger.await_args is not None
|
||||
assert trigger.await_args.args[0] == "user-1"
|
||||
assert (
|
||||
trigger.await_args.kwargs["start_type"]
|
||||
== ChatSessionStartType.AUTOPILOT_CALLBACK
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_copilot_session_returns_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.trigger_autopilot_session_for_user",
|
||||
AsyncMock(side_effect=LookupError("User not found with ID: missing-user")),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/copilot/trigger",
|
||||
json={
|
||||
"user_id": "missing-user",
|
||||
"start_type": ChatSessionStartType.AUTOPILOT_NIGHTLY.value,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "User not found with ID: missing-user"
|
||||
|
||||
|
||||
def test_send_pending_copilot_emails(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
send_emails = mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.send_pending_copilot_emails_for_user",
|
||||
AsyncMock(
|
||||
return_value=PendingCopilotEmailSweepResult(
|
||||
candidate_count=1,
|
||||
processed_count=1,
|
||||
sent_count=1,
|
||||
skipped_count=0,
|
||||
repair_queued_count=0,
|
||||
running_count=0,
|
||||
failed_count=0,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/copilot/send-emails",
|
||||
json={"user_id": "user-1"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"candidate_count": 1,
|
||||
"processed_count": 1,
|
||||
"sent_count": 1,
|
||||
"skipped_count": 0,
|
||||
"repair_queued_count": 0,
|
||||
"running_count": 0,
|
||||
"failed_count": 0,
|
||||
}
|
||||
send_emails.assert_awaited_once_with("user-1")
|
||||
|
||||
@@ -3,18 +3,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any, NoReturn
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.autopilot import (
|
||||
consume_callback_token,
|
||||
strip_internal_content,
|
||||
unwrap_internal_content,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
@@ -27,7 +32,13 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
)
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
@@ -53,6 +64,7 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
@@ -65,6 +77,187 @@ _UUID_RE = re.compile(
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
STREAM_QUEUE_GET_TIMEOUT_SECONDS = 10.0
|
||||
STREAMING_RESPONSE_HEADERS = {
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
}
|
||||
|
||||
|
||||
def _build_streaming_response(
|
||||
generator: AsyncGenerator[str, None],
|
||||
) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
media_type="text/event-stream",
|
||||
headers=STREAMING_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
async def _unsubscribe_stream_queue(
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] | None,
|
||||
) -> None:
|
||||
if subscriber_queue is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(session_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _stream_subscriber_queue(
|
||||
*,
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
log_meta: dict[str, Any],
|
||||
started_at: float,
|
||||
label: str,
|
||||
surface_errors: bool,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(),
|
||||
timeout=STREAM_QUEUE_GET_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
continue
|
||||
|
||||
chunk_count += 1
|
||||
if first_chunk_type is None:
|
||||
first_chunk_type = type(chunk).__name__
|
||||
elapsed = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] {label} first chunk at {elapsed:.1f}ms, type={first_chunk_type}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": first_chunk_type,
|
||||
"elapsed_ms": elapsed,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] {label} received StreamFinish in {total_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunk_count,
|
||||
"total_time_ms": total_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] {label} client disconnected after {chunk_count} chunks",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunk_count,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
elapsed = (time.perf_counter() - started_at) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] {label} error after {elapsed:.1f}ms: {exc}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(exc)}
|
||||
},
|
||||
)
|
||||
if surface_errors:
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
total_time = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] {label} finished in {total_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"chunks_yielded": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def _stream_chat_events(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
subscribe_from_id: str,
|
||||
turn_id: str,
|
||||
log_meta: dict[str, Any],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
started_at = time.perf_counter()
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
try:
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
async for chunk in _stream_subscriber_queue(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
log_meta=log_meta,
|
||||
started_at=started_at,
|
||||
label=f"stream_chat_post[{turn_id}]",
|
||||
surface_errors=True,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
await _unsubscribe_stream_queue(session_id, subscriber_queue)
|
||||
|
||||
|
||||
async def _resume_stream_events(
|
||||
*,
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
async for chunk in _stream_subscriber_queue(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
log_meta={"session_id": session_id},
|
||||
started_at=started_at,
|
||||
label=f"resume_stream[{session_id}]",
|
||||
surface_errors=False,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
await _unsubscribe_stream_queue(session_id, subscriber_queue)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
@@ -118,6 +311,8 @@ class SessionDetailResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
start_type: ChatSessionStartType
|
||||
execution_tag: str | None = None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
@@ -129,6 +324,8 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
start_type: ChatSessionStartType
|
||||
execution_tag: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
@@ -160,6 +357,14 @@ class UpdateSessionTitleRequest(BaseModel):
|
||||
return stripped
|
||||
|
||||
|
||||
class ConsumeCallbackTokenRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
class ConsumeCallbackTokenResponse(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -171,6 +376,7 @@ async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
with_auto: bool = Query(default=False),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
@@ -186,7 +392,12 @@ async def list_sessions(
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
sessions, total_count = await get_user_sessions(
|
||||
user_id,
|
||||
limit,
|
||||
offset,
|
||||
with_auto=with_auto,
|
||||
)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
@@ -217,6 +428,8 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
@@ -368,12 +581,26 @@ async def get_session(
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
messages = []
|
||||
for message in session.messages:
|
||||
payload = message.model_dump()
|
||||
if message.role == "user":
|
||||
visible_content = strip_internal_content(message.content)
|
||||
if (
|
||||
visible_content is None
|
||||
and session.start_type != ChatSessionStartType.MANUAL
|
||||
):
|
||||
visible_content = unwrap_internal_content(message.content)
|
||||
if visible_content is None:
|
||||
continue
|
||||
payload["content"] = visible_content
|
||||
messages.append(payload)
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
session_id,
|
||||
user_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
@@ -394,11 +621,28 @@ async def get_session(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/callback-token/consume",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def consume_callback_token_route(
|
||||
request: ConsumeCallbackTokenRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConsumeCallbackTokenResponse:
|
||||
try:
|
||||
result = await consume_callback_token(request.token, user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
return ConsumeCallbackTokenResponse(session_id=result.session_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/cancel",
|
||||
status_code=200,
|
||||
@@ -472,9 +716,6 @@ async def stream_chat_post(
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
@@ -506,18 +747,14 @@ async def stream_chat_post(
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
files = await workspace_db().get_workspace_files_by_ids(
|
||||
workspace_id=workspace.id,
|
||||
file_ids=valid_ids,
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
f"- {wf.name} ({wf.mime_type}, {round(wf.size_bytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
@@ -587,141 +824,14 @@ async def stream_chat_post(
|
||||
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
return _build_streaming_response(
|
||||
_stream_chat_events(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
subscribe_from_id=subscribe_from_id,
|
||||
turn_id=turn_id,
|
||||
log_meta=log_meta,
|
||||
)
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
"[TIMING] Starting to read from subscriber_queue",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
first_chunk_yielded = True
|
||||
elapsed = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||
f"type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||
f"n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"total_time_ms": total_time * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
# Surface error to frontend so it doesn't appear stuck
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time * 1000,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -747,11 +857,7 @@ async def resume_session_stream(
|
||||
StreamingResponse (SSE) when an active stream exists,
|
||||
or 204 No Content when there is nothing to resume.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
|
||||
if not active_session:
|
||||
return Response(status_code=204)
|
||||
@@ -768,64 +874,11 @@ async def resume_session_stream(
|
||||
|
||||
if subscriber_queue is None:
|
||||
return Response(status_code=204)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
"Resume stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"n_chunks": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
return _build_streaming_response(
|
||||
_resume_stream_events(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -977,6 +1030,6 @@ ToolResponseUnion = (
|
||||
description="This endpoint is not meant to be called. It exists solely to "
|
||||
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||
)
|
||||
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||
async def _tool_response_schema() -> NoReturn:
|
||||
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
@@ -8,6 +11,9 @@ import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
@@ -115,6 +121,238 @@ def test_update_title_not_found(
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_list_sessions_defaults_to_manual_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
started_at = datetime.now(timezone.utc)
|
||||
mock_get_user_sessions = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_user_sessions",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(
|
||||
[
|
||||
SimpleNamespace(
|
||||
session_id="sess-1",
|
||||
started_at=started_at,
|
||||
updated_at=started_at,
|
||||
title="Nightly check-in",
|
||||
start_type=chat_routes.ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag="autopilot-nightly:2026-03-13",
|
||||
)
|
||||
],
|
||||
1,
|
||||
),
|
||||
)
|
||||
|
||||
pipe = MagicMock()
|
||||
pipe.hget = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=["running"])
|
||||
redis = MagicMock()
|
||||
redis.pipeline = MagicMock(return_value=pipe)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=redis,
|
||||
)
|
||||
|
||||
response = client.get("/sessions")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"sessions": [
|
||||
{
|
||||
"id": "sess-1",
|
||||
"created_at": started_at.isoformat(),
|
||||
"updated_at": started_at.isoformat(),
|
||||
"title": "Nightly check-in",
|
||||
"start_type": "AUTOPILOT_NIGHTLY",
|
||||
"execution_tag": "autopilot-nightly:2026-03-13",
|
||||
"is_processing": True,
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
}
|
||||
mock_get_user_sessions.assert_awaited_once_with(
|
||||
test_user_id,
|
||||
50,
|
||||
0,
|
||||
with_auto=False,
|
||||
)
|
||||
|
||||
|
||||
def test_list_sessions_can_include_auto_sessions(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_get_user_sessions = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_user_sessions",
|
||||
new_callable=AsyncMock,
|
||||
return_value=([], 0),
|
||||
)
|
||||
|
||||
response = client.get("/sessions?with_auto=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"sessions": [], "total": 0}
|
||||
mock_get_user_sessions.assert_awaited_once_with(
|
||||
test_user_id,
|
||||
50,
|
||||
0,
|
||||
with_auto=True,
|
||||
)
|
||||
|
||||
|
||||
def test_consume_callback_token_route_returns_session_id(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mock_consume = mocker.patch(
|
||||
"backend.api.features.chat.routes.consume_callback_token",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SimpleNamespace(session_id="sess-2"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/callback-token/consume",
|
||||
json={"token": "token-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"session_id": "sess-2"}
|
||||
mock_consume.assert_awaited_once_with("token-123", TEST_USER_ID)
|
||||
|
||||
|
||||
def test_consume_callback_token_route_returns_404_on_invalid_token(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.consume_callback_token",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("Callback token not found"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/callback-token/consume",
|
||||
json={"token": "token-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Callback token not found"}
|
||||
|
||||
|
||||
def test_get_session_hides_internal_only_messages_for_manual_sessions(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
session = ChatSession.new(
|
||||
TEST_USER_ID,
|
||||
start_type=ChatSessionStartType.MANUAL,
|
||||
)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="<internal>hidden</internal>"),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content="Visible<internal>hidden</internal> text",
|
||||
),
|
||||
ChatMessage(role="assistant", content="Public response"),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get(f"/sessions/{session.session_id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Visible text",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Public response",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_session_shows_cleaned_internal_kickoff_for_autopilot_sessions(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
session = ChatSession.new(
|
||||
TEST_USER_ID,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag="autopilot-nightly:2026-03-13",
|
||||
)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="<internal>hidden</internal>"),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content="Visible<internal>hidden</internal> text",
|
||||
),
|
||||
ChatMessage(role="assistant", content="Public response"),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get(f"/sessions/{session.session_id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hidden",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Visible text",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Public response",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
@@ -142,7 +380,11 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
subscriber_queue = asyncio.Queue()
|
||||
subscriber_queue.put_nowait(StreamFinish())
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mock_registry.subscribe_to_session = mocker.AsyncMock(return_value=subscriber_queue)
|
||||
mock_registry.unsubscribe_from_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
@@ -165,11 +407,11 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
@@ -195,11 +437,11 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
@@ -217,9 +459,10 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
|
||||
workspace_id="ws-1",
|
||||
file_ids=[valid_id],
|
||||
)
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
@@ -233,11 +476,11 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
@@ -246,9 +489,10 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
|
||||
workspace_id="my-workspace-id",
|
||||
file_ids=[fid],
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
@@ -11,10 +11,7 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
@@ -181,8 +178,7 @@ class FileOperation(StrEnum):
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
@@ -279,11 +275,11 @@ class GithubMultiFileCommitBlock(Block):
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
@@ -305,19 +301,10 @@ class GithubMultiFileCommitBlock(Block):
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
@@ -371,36 +358,15 @@ class GithubMultiFileCommitBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
@@ -98,11 +97,7 @@ async def test_multi_file_commit_error_path():
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
135
autogpt_platform/backend/backend/copilot/autopilot.py
Normal file
135
autogpt_platform/backend/backend/copilot/autopilot.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Autopilot public API — thin facade re-exporting from sub-modules.
|
||||
|
||||
Implementation is split by responsibility:
|
||||
- autopilot_prompts: constants, prompt templates, context builders
|
||||
- autopilot_dispatch: timezone helpers, session creation, dispatch/scheduling
|
||||
- autopilot_completion: completion report extraction, repair, handler
|
||||
- autopilot_email: email sending, link building, notification sweep
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.autopilot_completion import ( # noqa: F401
|
||||
CompletionReportToolCall,
|
||||
CompletionReportToolCallFunction,
|
||||
ToolOutputEnvelope,
|
||||
_build_completion_report_repair_message,
|
||||
_extract_completion_report_from_session,
|
||||
_get_pending_approval_metadata,
|
||||
_queue_completion_report_repair,
|
||||
handle_non_manual_session_completion,
|
||||
)
|
||||
from backend.copilot.autopilot_dispatch import ( # noqa: F401
|
||||
_bucket_end_for_now,
|
||||
_create_autopilot_session,
|
||||
_crosses_local_midnight,
|
||||
_enqueue_session_turn,
|
||||
_resolve_timezone_name,
|
||||
_session_exists_for_execution_tag,
|
||||
_try_create_callback_session,
|
||||
_try_create_invite_cta_session,
|
||||
_try_create_nightly_session,
|
||||
_user_has_recent_manual_message,
|
||||
_user_has_session_since,
|
||||
dispatch_nightly_copilot,
|
||||
get_callback_execution_tag,
|
||||
get_graph_exec_id_for_session,
|
||||
get_invite_cta_execution_tag,
|
||||
get_nightly_execution_tag,
|
||||
trigger_autopilot_session_for_user,
|
||||
)
|
||||
from backend.copilot.autopilot_email import ( # noqa: F401
|
||||
PendingCopilotEmailSweepResult,
|
||||
_build_session_link,
|
||||
_get_completion_email_template_name,
|
||||
_markdown_to_email_html,
|
||||
_send_completion_email,
|
||||
_send_nightly_copilot_emails,
|
||||
send_nightly_copilot_emails,
|
||||
send_pending_copilot_emails_for_user,
|
||||
)
|
||||
from backend.copilot.autopilot_prompts import ( # noqa: F401
|
||||
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_CALLBACK_TAG,
|
||||
AUTOPILOT_DISABLED_TOOLS,
|
||||
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_INVITE_CTA_TAG,
|
||||
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_NIGHTLY_TAG_PREFIX,
|
||||
DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT,
|
||||
DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT,
|
||||
DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT,
|
||||
INTERNAL_TAG_RE,
|
||||
MAX_COMPLETION_REPORT_REPAIRS,
|
||||
_build_autopilot_system_prompt,
|
||||
_format_start_type_label,
|
||||
_get_recent_manual_session_context,
|
||||
_get_recent_sent_email_context,
|
||||
_get_recent_session_summary_context,
|
||||
strip_internal_content,
|
||||
unwrap_internal_content,
|
||||
wrap_internal_message,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage, create_chat_session
|
||||
from backend.data.db_accessors import chat_db
|
||||
|
||||
# -- re-exports from sub-modules (preserves existing import paths) ---------- #
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackTokenConsumeResult(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
async def consume_callback_token(
|
||||
token_id: str,
|
||||
user_id: str,
|
||||
) -> CallbackTokenConsumeResult:
|
||||
"""Consume a callback token and return the resulting session.
|
||||
|
||||
Uses an atomic consume-and-update to prevent the TOCTOU race where two
|
||||
concurrent requests could each see the token as unconsumed and create
|
||||
duplicate sessions.
|
||||
"""
|
||||
db = chat_db()
|
||||
token = await db.get_chat_session_callback_token(token_id)
|
||||
if token is None or token.user_id != user_id:
|
||||
raise ValueError("Callback token not found")
|
||||
if token.expires_at <= datetime.now(UTC):
|
||||
raise ValueError("Callback token has expired")
|
||||
|
||||
if token.consumed_session_id:
|
||||
return CallbackTokenConsumeResult(session_id=token.consumed_session_id)
|
||||
|
||||
session = await create_chat_session(
|
||||
user_id,
|
||||
initial_messages=[
|
||||
ChatMessage(role="assistant", content=token.callback_session_message)
|
||||
],
|
||||
)
|
||||
|
||||
# Atomically mark consumed — if another request already consumed the token
|
||||
# concurrently, the DB will have a non-null consumed_session_id; we re-read
|
||||
# and return the winner's session instead.
|
||||
await db.mark_chat_session_callback_token_consumed(
|
||||
token_id,
|
||||
session.session_id,
|
||||
)
|
||||
|
||||
# Re-read to see if we won the race
|
||||
refreshed = await db.get_chat_session_callback_token(token_id)
|
||||
if (
|
||||
refreshed
|
||||
and refreshed.consumed_session_id
|
||||
and refreshed.consumed_session_id != session.session_id
|
||||
):
|
||||
return CallbackTokenConsumeResult(session_id=refreshed.consumed_session_id)
|
||||
|
||||
return CallbackTokenConsumeResult(session_id=session.session_id)
|
||||
198
autogpt_platform/backend/backend/copilot/autopilot_completion.py
Normal file
198
autogpt_platform/backend/backend/copilot/autopilot_completion.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.copilot.autopilot_dispatch import (
|
||||
_enqueue_session_turn,
|
||||
get_graph_exec_id_for_session,
|
||||
)
|
||||
from backend.copilot.autopilot_prompts import (
|
||||
MAX_COMPLETION_REPORT_REPAIRS,
|
||||
wrap_internal_message,
|
||||
)
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.session_types import CompletionReportInput, StoredCompletionReport
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --------------- models --------------- #
|
||||
|
||||
|
||||
class CompletionReportToolCallFunction(BaseModel):
|
||||
name: str | None = None
|
||||
arguments: str | None = None
|
||||
|
||||
|
||||
class CompletionReportToolCall(BaseModel):
|
||||
id: str
|
||||
function: CompletionReportToolCallFunction = Field(
|
||||
default_factory=CompletionReportToolCallFunction
|
||||
)
|
||||
|
||||
|
||||
class ToolOutputEnvelope(BaseModel):
|
||||
type: str | None = None
|
||||
|
||||
|
||||
# --------------- approval metadata --------------- #
|
||||
|
||||
|
||||
async def _get_pending_approval_metadata(
|
||||
session: ChatSession,
|
||||
) -> tuple[int, str | None]:
|
||||
graph_exec_id = get_graph_exec_id_for_session(session.session_id)
|
||||
pending_count = await review_db().count_pending_reviews_for_graph_exec(
|
||||
graph_exec_id,
|
||||
session.user_id,
|
||||
)
|
||||
return pending_count, graph_exec_id if pending_count > 0 else None
|
||||
|
||||
|
||||
# --------------- extraction --------------- #
|
||||
|
||||
|
||||
def _extract_completion_report_from_session(
|
||||
session: ChatSession,
|
||||
*,
|
||||
pending_approval_count: int,
|
||||
) -> CompletionReportInput | None:
|
||||
tool_outputs = {
|
||||
message.tool_call_id: message.content
|
||||
for message in session.messages
|
||||
if message.role == "tool" and message.tool_call_id
|
||||
}
|
||||
|
||||
latest_report: CompletionReportInput | None = None
|
||||
for message in session.messages:
|
||||
if message.role != "assistant" or not message.tool_calls:
|
||||
continue
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
try:
|
||||
parsed_tool_call = CompletionReportToolCall.model_validate(tool_call)
|
||||
except ValidationError:
|
||||
continue
|
||||
|
||||
if parsed_tool_call.function.name != "completion_report":
|
||||
continue
|
||||
|
||||
output = tool_outputs.get(parsed_tool_call.id)
|
||||
if not output:
|
||||
continue
|
||||
|
||||
try:
|
||||
output_payload = ToolOutputEnvelope.model_validate_json(output)
|
||||
except ValidationError:
|
||||
output_payload = None
|
||||
|
||||
if output_payload is not None and output_payload.type == "error":
|
||||
continue
|
||||
|
||||
try:
|
||||
raw_arguments = parsed_tool_call.function.arguments or "{}"
|
||||
report = CompletionReportInput.model_validate_json(raw_arguments)
|
||||
except ValidationError:
|
||||
continue
|
||||
|
||||
if pending_approval_count > 0 and not report.approval_summary:
|
||||
continue
|
||||
|
||||
latest_report = report
|
||||
|
||||
return latest_report
|
||||
|
||||
|
||||
# --------------- repair --------------- #
|
||||
|
||||
|
||||
def _build_completion_report_repair_message(
|
||||
*,
|
||||
attempt: int,
|
||||
pending_approval_count: int,
|
||||
) -> str:
|
||||
approval_instruction = ""
|
||||
if pending_approval_count > 0:
|
||||
approval_instruction = (
|
||||
f" There are currently {pending_approval_count} pending approval item(s). "
|
||||
"If they still exist, include approval_summary."
|
||||
)
|
||||
|
||||
return wrap_internal_message(
|
||||
"The session completed without a valid completion_report tool call. "
|
||||
f"This is repair attempt {attempt}. Call completion_report now and do not do any additional user-facing work."
|
||||
+ approval_instruction
|
||||
)
|
||||
|
||||
|
||||
async def _queue_completion_report_repair(
|
||||
session: ChatSession,
|
||||
*,
|
||||
pending_approval_count: int,
|
||||
) -> None:
|
||||
attempt = session.completion_report_repair_count + 1
|
||||
repair_message = _build_completion_report_repair_message(
|
||||
attempt=attempt,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
session.messages.append(ChatMessage(role="user", content=repair_message))
|
||||
session.completion_report_repair_count = attempt
|
||||
session.completion_report_repair_queued_at = datetime.now(UTC)
|
||||
session.completed_at = None
|
||||
session.completion_report = None
|
||||
await upsert_chat_session(session)
|
||||
await _enqueue_session_turn(
|
||||
session,
|
||||
message=repair_message,
|
||||
tool_name="completion_report_repair",
|
||||
)
|
||||
|
||||
|
||||
# --------------- handler --------------- #
|
||||
|
||||
|
||||
async def handle_non_manual_session_completion(session_id: str) -> None:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None or session.is_manual:
|
||||
return
|
||||
|
||||
pending_approval_count, graph_exec_id = await _get_pending_approval_metadata(
|
||||
session
|
||||
)
|
||||
report = _extract_completion_report_from_session(
|
||||
session,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
|
||||
if report is not None:
|
||||
session.completion_report = StoredCompletionReport(
|
||||
**report.model_dump(),
|
||||
has_pending_approvals=pending_approval_count > 0,
|
||||
pending_approval_count=pending_approval_count,
|
||||
pending_approval_graph_exec_id=graph_exec_id,
|
||||
saved_at=datetime.now(UTC),
|
||||
)
|
||||
session.completion_report_repair_queued_at = None
|
||||
session.completed_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
return
|
||||
|
||||
if session.completion_report_repair_count >= MAX_COMPLETION_REPORT_REPAIRS:
|
||||
session.completion_report_repair_queued_at = None
|
||||
session.completed_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
return
|
||||
|
||||
await _queue_completion_report_repair(
|
||||
session,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
386
autogpt_platform/backend/backend/copilot/autopilot_dispatch.py
Normal file
386
autogpt_platform/backend/backend/copilot/autopilot_dispatch.py
Normal file
@@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.autopilot_prompts import (
|
||||
AUTOPILOT_CALLBACK_TAG,
|
||||
AUTOPILOT_DISABLED_TOOLS,
|
||||
AUTOPILOT_INVITE_CTA_TAG,
|
||||
AUTOPILOT_NIGHTLY_TAG_PREFIX,
|
||||
_build_autopilot_system_prompt,
|
||||
_render_initial_message,
|
||||
)
|
||||
from backend.copilot.constants import COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.executor.utils import enqueue_copilot_turn
|
||||
from backend.copilot.model import ChatMessage, ChatSession, create_chat_session
|
||||
from backend.copilot.session_types import ChatSessionConfig, ChatSessionStartType
|
||||
from backend.data.db_accessors import chat_db, invited_user_db, user_db
|
||||
from backend.data.model import User
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import InvitedUserRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
DISPATCH_BATCH_SIZE = 500
|
||||
|
||||
|
||||
# --------------- tag helpers --------------- #
|
||||
|
||||
|
||||
def get_graph_exec_id_for_session(session_id: str) -> str:
|
||||
return f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
|
||||
|
||||
def get_nightly_execution_tag(target_local_date: date) -> str:
|
||||
return f"{AUTOPILOT_NIGHTLY_TAG_PREFIX}{target_local_date.isoformat()}"
|
||||
|
||||
|
||||
def get_callback_execution_tag() -> str:
|
||||
return AUTOPILOT_CALLBACK_TAG
|
||||
|
||||
|
||||
def get_invite_cta_execution_tag() -> str:
|
||||
return AUTOPILOT_INVITE_CTA_TAG
|
||||
|
||||
|
||||
def _get_manual_trigger_execution_tag(start_type: ChatSessionStartType) -> str:
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%fZ")
|
||||
return f"admin-autopilot:{start_type.value}:{timestamp}:{uuid4()}"
|
||||
|
||||
|
||||
# --------------- timezone helpers --------------- #
|
||||
|
||||
|
||||
def _bucket_end_for_now(now_utc: datetime) -> datetime:
|
||||
minute = 30 if now_utc.minute >= 30 else 0
|
||||
return now_utc.replace(minute=minute, second=0, microsecond=0)
|
||||
|
||||
|
||||
def _resolve_timezone_name(raw_timezone: str | None) -> str:
|
||||
return get_user_timezone_or_utc(raw_timezone)
|
||||
|
||||
|
||||
def _crosses_local_midnight(
|
||||
bucket_start_utc: datetime,
|
||||
bucket_end_utc: datetime,
|
||||
timezone_name: str,
|
||||
) -> date | None:
|
||||
"""Return the new local date if *bucket_end_utc* falls on a different local
|
||||
date than *bucket_start_utc*, taking DST transitions into account.
|
||||
|
||||
During a DST spring-forward the wall clock jumps forward (e.g. 01:59 →
|
||||
03:00). We use ``fold=0`` on the end instant so that ambiguous/missing
|
||||
times are resolved consistently and a single 30-min UTC bucket can never
|
||||
produce midnight on *two* consecutive calls.
|
||||
"""
|
||||
tz = ZoneInfo(timezone_name)
|
||||
start_local = bucket_start_utc.astimezone(tz)
|
||||
end_local = bucket_end_utc.astimezone(tz)
|
||||
# Resolve ambiguous wall-clock times consistently (spring-forward / fall-back)
|
||||
end_local = end_local.replace(fold=0)
|
||||
if start_local.date() == end_local.date():
|
||||
return None
|
||||
return end_local.date()
|
||||
|
||||
|
||||
# --------------- thin DB wrappers --------------- #
|
||||
|
||||
|
||||
async def _user_has_recent_manual_message(user_id: str, since: datetime) -> bool:
|
||||
return await chat_db().has_recent_manual_message(user_id, since)
|
||||
|
||||
|
||||
async def _user_has_session_since(user_id: str, since: datetime) -> bool:
|
||||
return await chat_db().has_session_since(user_id, since)
|
||||
|
||||
|
||||
async def _session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
|
||||
return await chat_db().session_exists_for_execution_tag(user_id, execution_tag)
|
||||
|
||||
|
||||
# --------------- session creation --------------- #
|
||||
|
||||
|
||||
async def _enqueue_session_turn(
|
||||
session: ChatSession,
|
||||
*,
|
||||
message: str,
|
||||
tool_name: str,
|
||||
) -> None:
|
||||
turn_id = str(uuid4())
|
||||
await stream_registry.create_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
tool_call_id=tool_name,
|
||||
tool_name=tool_name,
|
||||
turn_id=turn_id,
|
||||
blocking=False,
|
||||
)
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
message=message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
|
||||
async def _create_autopilot_session(
|
||||
user: User,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
execution_tag: str,
|
||||
timezone_name: str,
|
||||
target_local_date: date | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> ChatSession | None:
|
||||
if await _session_exists_for_execution_tag(user.id, execution_tag):
|
||||
return None
|
||||
|
||||
system_prompt = await _build_autopilot_system_prompt(
|
||||
user,
|
||||
start_type=start_type,
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
initial_message = _render_initial_message(
|
||||
start_type,
|
||||
user_name=user.name,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
session_config = ChatSessionConfig(
|
||||
system_prompt_override=system_prompt,
|
||||
initial_user_message=initial_message,
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=AUTOPILOT_DISABLED_TOOLS,
|
||||
)
|
||||
|
||||
session = await create_chat_session(
|
||||
user.id,
|
||||
start_type=start_type,
|
||||
execution_tag=execution_tag,
|
||||
session_config=session_config,
|
||||
initial_messages=[ChatMessage(role="user", content=initial_message)],
|
||||
)
|
||||
await _enqueue_session_turn(
|
||||
session,
|
||||
message=initial_message,
|
||||
tool_name="autopilot_dispatch",
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
# --------------- cohort helpers --------------- #
|
||||
|
||||
|
||||
async def _try_create_invite_cta_session(
|
||||
user: User,
|
||||
*,
|
||||
invited_user: InvitedUserRecord | None,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
invite_cta_start: date,
|
||||
invite_cta_delay: timedelta,
|
||||
) -> bool:
|
||||
if invited_user is None:
|
||||
return False
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
return False
|
||||
if invited_user.created_at.date() < invite_cta_start:
|
||||
return False
|
||||
if invited_user.created_at > now_utc - invite_cta_delay:
|
||||
return False
|
||||
if await _session_exists_for_execution_tag(user.id, get_invite_cta_execution_tag()):
|
||||
return False
|
||||
|
||||
created = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_INVITE_CTA,
|
||||
execution_tag=get_invite_cta_execution_tag(),
|
||||
timezone_name=timezone_name,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
return created is not None
|
||||
|
||||
|
||||
async def _try_create_nightly_session(
|
||||
user: User,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
target_local_date: date,
|
||||
) -> bool:
|
||||
if not await _user_has_recent_manual_message(
|
||||
user.id,
|
||||
now_utc - timedelta(hours=24),
|
||||
):
|
||||
return False
|
||||
|
||||
created = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag=get_nightly_execution_tag(target_local_date),
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
)
|
||||
return created is not None
|
||||
|
||||
|
||||
async def _try_create_callback_session(
|
||||
user: User,
|
||||
*,
|
||||
callback_start: datetime,
|
||||
timezone_name: str,
|
||||
) -> bool:
|
||||
if not await _user_has_session_since(user.id, callback_start):
|
||||
return False
|
||||
if await _session_exists_for_execution_tag(user.id, get_callback_execution_tag()):
|
||||
return False
|
||||
|
||||
created = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
execution_tag=get_callback_execution_tag(),
|
||||
timezone_name=timezone_name,
|
||||
)
|
||||
return created is not None
|
||||
|
||||
|
||||
# --------------- dispatch --------------- #
|
||||
|
||||
|
||||
async def _dispatch_nightly_copilot() -> int:
|
||||
now_utc = datetime.now(UTC)
|
||||
bucket_end = _bucket_end_for_now(now_utc)
|
||||
bucket_start = bucket_end - timedelta(minutes=30)
|
||||
callback_start = datetime.combine(
|
||||
settings.config.nightly_copilot_callback_start_date,
|
||||
time.min,
|
||||
tzinfo=UTC,
|
||||
)
|
||||
invite_cta_start = settings.config.nightly_copilot_invite_cta_start_date
|
||||
invite_cta_delay = timedelta(
|
||||
hours=settings.config.nightly_copilot_invite_cta_delay_hours
|
||||
)
|
||||
|
||||
# Paginate user list to avoid loading the entire table into memory.
|
||||
created_count = 0
|
||||
cursor: str | None = None
|
||||
while True:
|
||||
batch = await user_db().list_users(
|
||||
limit=DISPATCH_BATCH_SIZE,
|
||||
cursor=cursor,
|
||||
)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
user_ids = [user.id for user in batch]
|
||||
invites = await invited_user_db().list_invited_users_for_auth_users(user_ids)
|
||||
invites_by_user_id = {
|
||||
invite.auth_user_id: invite for invite in invites if invite.auth_user_id
|
||||
}
|
||||
|
||||
for user in batch:
|
||||
if not await is_feature_enabled(
|
||||
Flag.NIGHTLY_COPILOT, user.id, default=False
|
||||
):
|
||||
continue
|
||||
|
||||
timezone_name = _resolve_timezone_name(user.timezone)
|
||||
target_local_date = _crosses_local_midnight(
|
||||
bucket_start,
|
||||
bucket_end,
|
||||
timezone_name,
|
||||
)
|
||||
if target_local_date is None:
|
||||
continue
|
||||
|
||||
invited_user = invites_by_user_id.get(user.id)
|
||||
if await _try_create_invite_cta_session(
|
||||
user,
|
||||
invited_user=invited_user,
|
||||
now_utc=now_utc,
|
||||
timezone_name=timezone_name,
|
||||
invite_cta_start=invite_cta_start,
|
||||
invite_cta_delay=invite_cta_delay,
|
||||
):
|
||||
created_count += 1
|
||||
continue
|
||||
|
||||
if await _try_create_nightly_session(
|
||||
user,
|
||||
now_utc=now_utc,
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
):
|
||||
created_count += 1
|
||||
continue
|
||||
|
||||
if await _try_create_callback_session(
|
||||
user,
|
||||
callback_start=callback_start,
|
||||
timezone_name=timezone_name,
|
||||
):
|
||||
created_count += 1
|
||||
|
||||
cursor = batch[-1].id if len(batch) == DISPATCH_BATCH_SIZE else None
|
||||
if cursor is None:
|
||||
break
|
||||
|
||||
return created_count
|
||||
|
||||
|
||||
async def dispatch_nightly_copilot() -> int:
|
||||
return await _dispatch_nightly_copilot()
|
||||
|
||||
|
||||
async def trigger_autopilot_session_for_user(
|
||||
user_id: str,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
) -> ChatSession:
|
||||
allowed_start_types = {
|
||||
ChatSessionStartType.AUTOPILOT_INVITE_CTA,
|
||||
ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
}
|
||||
if start_type not in allowed_start_types:
|
||||
raise ValueError(f"Unsupported autopilot start type: {start_type}")
|
||||
|
||||
try:
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
except ValueError as exc:
|
||||
raise LookupError(str(exc)) from exc
|
||||
|
||||
invites = await invited_user_db().list_invited_users_for_auth_users([user_id])
|
||||
invited_user = invites[0] if invites else None
|
||||
timezone_name = _resolve_timezone_name(user.timezone)
|
||||
target_local_date = None
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
target_local_date = datetime.now(UTC).astimezone(ZoneInfo(timezone_name)).date()
|
||||
|
||||
session = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=start_type,
|
||||
execution_tag=_get_manual_trigger_execution_tag(start_type),
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
if session is None:
|
||||
raise ValueError("Failed to create autopilot session")
|
||||
|
||||
return session
|
||||
297
autogpt_platform/backend/backend/copilot/autopilot_email.py
Normal file
297
autogpt_platform/backend/backend/copilot/autopilot_email.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.autopilot_completion import (
|
||||
_get_pending_approval_metadata,
|
||||
_queue_completion_report_repair,
|
||||
)
|
||||
from backend.copilot.autopilot_prompts import (
|
||||
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE,
|
||||
MAX_COMPLETION_REPORT_REPAIRS,
|
||||
)
|
||||
from backend.copilot.model import (
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.service import _generate_session_title
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.db_accessors import chat_db, user_db
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.url import get_frontend_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
PENDING_NOTIFICATION_SWEEP_LIMIT = 200
|
||||
|
||||
_md = MarkdownIt()
|
||||
|
||||
_EMAIL_INLINE_STYLES: list[tuple[str, str]] = [
|
||||
(
|
||||
"<p>",
|
||||
'<p style="font-size: 15px; line-height: 170%;'
|
||||
" margin-top: 0; margin-bottom: 16px;"
|
||||
' color: #1F1F20;">',
|
||||
),
|
||||
(
|
||||
"<li>",
|
||||
'<li style="font-size: 15px; line-height: 170%;'
|
||||
" margin-top: 0; margin-bottom: 8px;"
|
||||
' color: #1F1F20;">',
|
||||
),
|
||||
(
|
||||
"<ul>",
|
||||
'<ul style="padding: 0 0 0 24px;' ' margin-top: 0; margin-bottom: 16px;">',
|
||||
),
|
||||
(
|
||||
"<ol>",
|
||||
'<ol style="padding: 0 0 0 24px;' ' margin-top: 0; margin-bottom: 16px;">',
|
||||
),
|
||||
(
|
||||
"<a ",
|
||||
'<a style="color: #7733F5;'
|
||||
" text-decoration: underline;"
|
||||
' font-weight: 500;" ',
|
||||
),
|
||||
(
|
||||
"<h2>",
|
||||
'<h2 style="font-size: 20px; font-weight: 600;'
|
||||
' margin-top: 0; margin-bottom: 12px; color: #1F1F20;">',
|
||||
),
|
||||
(
|
||||
"<h3>",
|
||||
'<h3 style="font-size: 18px; font-weight: 600;'
|
||||
' margin-top: 0; margin-bottom: 12px; color: #1F1F20;">',
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _markdown_to_email_html(text: str | None) -> str:
|
||||
"""Convert markdown text to email-safe HTML with inline styles."""
|
||||
if not text or not text.strip():
|
||||
return ""
|
||||
html = _md.render(text.strip())
|
||||
for tag, styled_tag in _EMAIL_INLINE_STYLES:
|
||||
html = html.replace(tag, styled_tag)
|
||||
return html.strip()
|
||||
|
||||
|
||||
# --------------- link builders --------------- #
|
||||
|
||||
|
||||
def _build_session_link(session_id: str, *, show_autopilot: bool) -> str:
|
||||
base_url = get_frontend_base_url()
|
||||
suffix = "&showAutopilot=1" if show_autopilot else ""
|
||||
return f"{base_url}/copilot?sessionId={session_id}{suffix}"
|
||||
|
||||
|
||||
def _get_completion_email_template_name(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return AUTOPILOT_CALLBACK_EMAIL_TEMPLATE
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE
|
||||
raise ValueError(f"Unsupported start type for completion email: {start_type}")
|
||||
|
||||
|
||||
class PendingCopilotEmailSweepResult(BaseModel):
|
||||
candidate_count: int = 0
|
||||
processed_count: int = 0
|
||||
sent_count: int = 0
|
||||
skipped_count: int = 0
|
||||
repair_queued_count: int = 0
|
||||
running_count: int = 0
|
||||
failed_count: int = 0
|
||||
|
||||
|
||||
async def _ensure_session_title_for_completed_session(session: ChatSession) -> None:
|
||||
if session.title or not session.user_id:
|
||||
return
|
||||
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
return
|
||||
|
||||
title = report.email_title.strip() if report.email_title else ""
|
||||
if not title:
|
||||
title_seed = report.email_body or report.thoughts
|
||||
if title_seed:
|
||||
generated_title = await _generate_session_title(
|
||||
title_seed,
|
||||
user_id=session.user_id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
title = generated_title.strip() if generated_title else ""
|
||||
|
||||
if not title:
|
||||
return
|
||||
|
||||
updated = await update_session_title(
|
||||
session.session_id,
|
||||
session.user_id,
|
||||
title,
|
||||
only_if_empty=True,
|
||||
)
|
||||
if updated:
|
||||
session.title = title
|
||||
|
||||
|
||||
# --------------- send email --------------- #
|
||||
|
||||
|
||||
async def _send_completion_email(session: ChatSession) -> None:
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
raise ValueError("Missing completion report")
|
||||
try:
|
||||
user = await user_db().get_user_by_id(session.user_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"User {session.user_id} not found") from exc
|
||||
|
||||
if not user.email:
|
||||
raise ValueError(f"User {session.user_id} not found")
|
||||
|
||||
approval_cta = report.has_pending_approvals
|
||||
template_name = _get_completion_email_template_name(session.start_type)
|
||||
if approval_cta:
|
||||
cta_url = _build_session_link(session.session_id, show_autopilot=True)
|
||||
cta_label = "Review in Copilot"
|
||||
else:
|
||||
cta_url = _build_session_link(session.session_id, show_autopilot=True)
|
||||
cta_label = (
|
||||
"Try Copilot"
|
||||
if session.start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA
|
||||
else "Open Copilot"
|
||||
)
|
||||
|
||||
# EmailSender.send_template is synchronous (blocking HTTP call to Postmark).
|
||||
# Run it in a thread to avoid blocking the async event loop.
|
||||
sender = EmailSender()
|
||||
await asyncio.to_thread(
|
||||
sender.send_template,
|
||||
user_email=user.email,
|
||||
subject=report.email_title or "Autopilot update",
|
||||
template_name=template_name,
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(report.email_body),
|
||||
"approval_summary_html": _markdown_to_email_html(report.approval_summary),
|
||||
"cta_url": cta_url,
|
||||
"cta_label": cta_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# --------------- email sweep --------------- #
|
||||
|
||||
|
||||
async def _process_pending_copilot_email_candidates(
|
||||
candidates: list,
|
||||
) -> PendingCopilotEmailSweepResult:
|
||||
result = PendingCopilotEmailSweepResult(candidate_count=len(candidates))
|
||||
|
||||
for candidate in candidates:
|
||||
session = await get_chat_session(candidate.session_id)
|
||||
if session is None or session.is_manual:
|
||||
continue
|
||||
|
||||
active = await stream_registry.get_session(session.session_id)
|
||||
is_running = active is not None and active.status == "running"
|
||||
if is_running:
|
||||
result.running_count += 1
|
||||
continue
|
||||
|
||||
pending_approval_count, graph_exec_id = await _get_pending_approval_metadata(
|
||||
session
|
||||
)
|
||||
|
||||
if session.completion_report is None:
|
||||
if session.completion_report_repair_count < MAX_COMPLETION_REPORT_REPAIRS:
|
||||
await _queue_completion_report_repair(
|
||||
session,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
result.repair_queued_count += 1
|
||||
continue
|
||||
|
||||
session.completed_at = session.completed_at or datetime.now(UTC)
|
||||
session.completion_report_repair_queued_at = None
|
||||
session.notification_email_skipped_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
result.skipped_count += 1
|
||||
continue
|
||||
|
||||
session.completed_at = session.completed_at or datetime.now(UTC)
|
||||
if (
|
||||
session.completion_report.pending_approval_graph_exec_id is None
|
||||
and graph_exec_id
|
||||
):
|
||||
session.completion_report = session.completion_report.model_copy(
|
||||
update={
|
||||
"has_pending_approvals": pending_approval_count > 0,
|
||||
"pending_approval_count": pending_approval_count,
|
||||
"pending_approval_graph_exec_id": graph_exec_id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await _ensure_session_title_for_completed_session(session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to ensure session title for session %s",
|
||||
session.session_id,
|
||||
)
|
||||
|
||||
if not session.completion_report.should_notify_user:
|
||||
session.notification_email_skipped_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
result.skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
await _send_completion_email(session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to send nightly copilot email for session %s",
|
||||
session.session_id,
|
||||
)
|
||||
result.failed_count += 1
|
||||
continue
|
||||
|
||||
session.notification_email_sent_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
result.sent_count += 1
|
||||
|
||||
result.processed_count = result.sent_count + result.skipped_count
|
||||
return result
|
||||
|
||||
|
||||
async def _send_nightly_copilot_emails() -> int:
|
||||
candidates = await chat_db().get_pending_notification_chat_sessions(
|
||||
limit=PENDING_NOTIFICATION_SWEEP_LIMIT
|
||||
)
|
||||
result = await _process_pending_copilot_email_candidates(candidates)
|
||||
return result.processed_count
|
||||
|
||||
|
||||
async def send_nightly_copilot_emails() -> int:
|
||||
return await _send_nightly_copilot_emails()
|
||||
|
||||
|
||||
async def send_pending_copilot_emails_for_user(
|
||||
user_id: str,
|
||||
) -> PendingCopilotEmailSweepResult:
|
||||
candidates = await chat_db().get_pending_notification_chat_sessions_for_user(
|
||||
user_id,
|
||||
limit=PENDING_NOTIFICATION_SWEEP_LIMIT,
|
||||
)
|
||||
return await _process_pending_copilot_email_candidates(candidates)
|
||||
409
autogpt_platform/backend/backend/copilot/autopilot_prompts.py
Normal file
409
autogpt_platform/backend/backend/copilot/autopilot_prompts.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from backend.copilot.service import _get_system_prompt_template
|
||||
from backend.copilot.service import config as chat_config
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.db_accessors import chat_db, understanding_db
|
||||
from backend.data.understanding import format_understanding_for_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import InvitedUserRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
INTERNAL_TAG_RE = re.compile(r"<internal>.*?</internal>", re.DOTALL)
|
||||
MAX_COMPLETION_REPORT_REPAIRS = 2
|
||||
AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT = 6000
|
||||
AUTOPILOT_RECENT_SESSION_LIMIT = 5
|
||||
AUTOPILOT_RECENT_MESSAGE_LIMIT = 6
|
||||
AUTOPILOT_MESSAGE_CHAR_LIMIT = 500
|
||||
AUTOPILOT_EMAIL_HISTORY_LIMIT = 5
|
||||
AUTOPILOT_SESSION_SUMMARY_LIMIT = 2
|
||||
|
||||
AUTOPILOT_NIGHTLY_TAG_PREFIX = "autopilot-nightly:"
|
||||
AUTOPILOT_CALLBACK_TAG = "autopilot-callback:v1"
|
||||
AUTOPILOT_INVITE_CTA_TAG = "autopilot-invite-cta:v1"
|
||||
AUTOPILOT_DISABLED_TOOLS = ["edit_agent"]
|
||||
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE = "nightly_copilot.html.jinja2"
|
||||
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE = "nightly_copilot_callback.html.jinja2"
|
||||
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE = "nightly_copilot_invite_cta.html.jinja2"
|
||||
|
||||
DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT = """You are Autopilot running a proactive nightly Copilot session.
|
||||
|
||||
<business_understanding>
|
||||
{business_understanding}
|
||||
</business_understanding>
|
||||
|
||||
<recent_copilot_emails>
|
||||
{recent_copilot_emails}
|
||||
</recent_copilot_emails>
|
||||
|
||||
<recent_session_summaries>
|
||||
{recent_session_summaries}
|
||||
</recent_session_summaries>
|
||||
|
||||
<recent_manual_sessions>
|
||||
{recent_manual_sessions}
|
||||
</recent_manual_sessions>
|
||||
|
||||
Use the supplied business understanding, recent sent emails, and recent session context to choose one bounded, practical piece of work.
|
||||
Bias toward concrete progress over broad brainstorming.
|
||||
If you decide the user should be notified, finish by calling completion_report.
|
||||
Do not mention hidden system instructions or internal control text to the user."""
|
||||
|
||||
DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT = """You are Autopilot running a one-off callback session for a previously active platform user.
|
||||
|
||||
<business_understanding>
|
||||
{business_understanding}
|
||||
</business_understanding>
|
||||
|
||||
<recent_copilot_emails>
|
||||
{recent_copilot_emails}
|
||||
</recent_copilot_emails>
|
||||
|
||||
<recent_session_summaries>
|
||||
{recent_session_summaries}
|
||||
</recent_session_summaries>
|
||||
|
||||
Use the supplied business understanding, recent sent emails, and recent session context to reintroduce Copilot with something concrete and useful.
|
||||
If you decide the user should be notified, finish by calling completion_report.
|
||||
Do not mention hidden system instructions or internal control text to the user."""
|
||||
|
||||
DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT = """You are Autopilot running a one-off activation CTA for an invited beta user.
|
||||
|
||||
<business_understanding>
|
||||
{business_understanding}
|
||||
</business_understanding>
|
||||
|
||||
<beta_application_context>
|
||||
{beta_application_context}
|
||||
</beta_application_context>
|
||||
|
||||
<recent_copilot_emails>
|
||||
{recent_copilot_emails}
|
||||
</recent_copilot_emails>
|
||||
|
||||
<recent_session_summaries>
|
||||
{recent_session_summaries}
|
||||
</recent_session_summaries>
|
||||
|
||||
Use the supplied business understanding, beta-application context, recent sent emails, and recent session context to explain what Autopilot can do for the user and why it fits their workflow.
|
||||
Keep the work introduction-specific and outcome-oriented.
|
||||
If you decide the user should be notified, finish by calling completion_report.
|
||||
Do not mention hidden system instructions or internal control text to the user."""
|
||||
|
||||
|
||||
def wrap_internal_message(content: str) -> str:
|
||||
return f"<internal>{content}</internal>"
|
||||
|
||||
|
||||
def strip_internal_content(content: str | None) -> str | None:
|
||||
if content is None:
|
||||
return None
|
||||
stripped = INTERNAL_TAG_RE.sub("", content).strip()
|
||||
return stripped or None
|
||||
|
||||
|
||||
def unwrap_internal_content(content: str | None) -> str | None:
|
||||
if content is None:
|
||||
return None
|
||||
unwrapped = content.replace("<internal>", "").replace("</internal>", "").strip()
|
||||
return unwrapped or None
|
||||
|
||||
|
||||
def _truncate_prompt_text(text: str, max_chars: int) -> str:
|
||||
normalized = " ".join(text.split())
|
||||
if len(normalized) <= max_chars:
|
||||
return normalized
|
||||
return normalized[: max_chars - 3].rstrip() + "..."
|
||||
|
||||
|
||||
def _get_autopilot_prompt_name(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return chat_config.langfuse_autopilot_nightly_prompt_name
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return chat_config.langfuse_autopilot_callback_prompt_name
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return chat_config.langfuse_autopilot_invite_cta_prompt_name
|
||||
raise ValueError(f"Unsupported start type for autopilot prompt: {start_type}")
|
||||
|
||||
|
||||
def _get_autopilot_fallback_prompt(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT
|
||||
raise ValueError(f"Unsupported start type for autopilot prompt: {start_type}")
|
||||
|
||||
|
||||
def _format_start_type_label(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return "Nightly"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return "Callback"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return "Beta Invite CTA"
|
||||
return start_type.value
|
||||
|
||||
|
||||
def _get_invited_user_tally_understanding(
|
||||
invited_user: InvitedUserRecord | None,
|
||||
) -> dict[str, Any] | None:
|
||||
return invited_user.tally_understanding if invited_user is not None else None
|
||||
|
||||
|
||||
def _render_initial_message(
|
||||
start_type: ChatSessionStartType,
|
||||
*,
|
||||
user_name: str | None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> str:
|
||||
display_name = user_name or "the user"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return wrap_internal_message(
|
||||
"This is a nightly proactive Copilot session. Review recent manual activity, "
|
||||
f"do one useful piece of work for {display_name}, and finish with completion_report."
|
||||
)
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return wrap_internal_message(
|
||||
"This is a one-off callback session for a previously active user. "
|
||||
f"Reintroduce Copilot with something concrete and useful for {display_name}, "
|
||||
"then finish with completion_report."
|
||||
)
|
||||
|
||||
invite_summary = ""
|
||||
tally_understanding = _get_invited_user_tally_understanding(invited_user)
|
||||
if tally_understanding is not None:
|
||||
invite_summary = "\nKnown context from the beta application:\n" + json.dumps(
|
||||
tally_understanding, ensure_ascii=False
|
||||
)
|
||||
return wrap_internal_message(
|
||||
"This is a one-off invite CTA session for an invited beta user who has not yet activated. "
|
||||
f"Create a tailored introduction for {display_name}, explain how Autopilot can help, "
|
||||
f"and finish with completion_report.{invite_summary}"
|
||||
)
|
||||
|
||||
|
||||
def _get_previous_local_midnight_utc(
|
||||
target_local_date: date,
|
||||
timezone_name: str,
|
||||
) -> datetime:
|
||||
from datetime import UTC
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
tz = ZoneInfo(timezone_name)
|
||||
previous_midnight_local = datetime.combine(
|
||||
target_local_date - timedelta(days=1),
|
||||
time.min,
|
||||
tzinfo=tz,
|
||||
)
|
||||
return previous_midnight_local.astimezone(UTC)
|
||||
|
||||
|
||||
async def _get_recent_manual_session_context(
|
||||
user_id: str,
|
||||
*,
|
||||
since_utc: datetime,
|
||||
) -> str:
|
||||
sessions = await chat_db().get_manual_chat_sessions_since(
|
||||
user_id,
|
||||
since_utc,
|
||||
AUTOPILOT_RECENT_SESSION_LIMIT,
|
||||
)
|
||||
|
||||
if not sessions:
|
||||
return "No recent manual sessions since the previous nightly run."
|
||||
|
||||
blocks: list[str] = []
|
||||
used_chars = 0
|
||||
|
||||
for session in sessions:
|
||||
messages = await chat_db().get_chat_messages_since(
|
||||
session.session_id, since_utc
|
||||
)
|
||||
|
||||
visible_messages: list[str] = []
|
||||
for message in messages[-AUTOPILOT_RECENT_MESSAGE_LIMIT:]:
|
||||
content = message.content or ""
|
||||
if message.role == "user":
|
||||
visible = strip_internal_content(content)
|
||||
else:
|
||||
visible = content.strip() or None
|
||||
if not visible:
|
||||
continue
|
||||
|
||||
role_label = {
|
||||
"user": "User",
|
||||
"assistant": "Assistant",
|
||||
"tool": "Tool",
|
||||
}.get(message.role, message.role.title())
|
||||
visible_messages.append(
|
||||
f"{role_label}: {_truncate_prompt_text(visible, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
|
||||
)
|
||||
|
||||
if not visible_messages:
|
||||
continue
|
||||
|
||||
title_suffix = f" ({session.title})" if session.title else ""
|
||||
block = (
|
||||
f"### Session updated {session.updated_at.isoformat()}{title_suffix}\n"
|
||||
+ "\n".join(visible_messages)
|
||||
)
|
||||
if used_chars + len(block) > AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT:
|
||||
break
|
||||
|
||||
blocks.append(block)
|
||||
used_chars += len(block)
|
||||
|
||||
return (
|
||||
"\n\n".join(blocks)
|
||||
if blocks
|
||||
else "No recent manual sessions since the previous nightly run."
|
||||
)
|
||||
|
||||
|
||||
async def _get_recent_sent_email_context(user_id: str) -> str:
|
||||
sessions = await chat_db().get_recent_sent_email_chat_sessions(
|
||||
user_id,
|
||||
AUTOPILOT_EMAIL_HISTORY_LIMIT,
|
||||
)
|
||||
if not sessions:
|
||||
return "No recent Copilot or Autopilot emails have been sent to this user."
|
||||
|
||||
blocks: list[str] = []
|
||||
for session in sessions:
|
||||
report = session.completion_report
|
||||
sent_at = session.notification_email_sent_at
|
||||
if report is None or sent_at is None:
|
||||
continue
|
||||
|
||||
lines = [
|
||||
f"### Sent {sent_at.isoformat()} ({_format_start_type_label(session.start_type)})",
|
||||
]
|
||||
if report.email_title:
|
||||
lines.append(
|
||||
f"Subject: {_truncate_prompt_text(report.email_title, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
|
||||
)
|
||||
if report.email_body:
|
||||
lines.append(
|
||||
f"Body: {_truncate_prompt_text(report.email_body, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
|
||||
)
|
||||
if report.callback_session_message:
|
||||
lines.append(
|
||||
"CTA Message: "
|
||||
+ _truncate_prompt_text(
|
||||
report.callback_session_message,
|
||||
AUTOPILOT_MESSAGE_CHAR_LIMIT,
|
||||
)
|
||||
)
|
||||
blocks.append("\n".join(lines))
|
||||
|
||||
return (
|
||||
"\n\n".join(blocks)
|
||||
if blocks
|
||||
else "No recent Copilot or Autopilot emails have been sent to this user."
|
||||
)
|
||||
|
||||
|
||||
async def _get_recent_session_summary_context(user_id: str) -> str:
|
||||
sessions = await chat_db().get_recent_completion_report_chat_sessions(
|
||||
user_id,
|
||||
AUTOPILOT_SESSION_SUMMARY_LIMIT,
|
||||
)
|
||||
if not sessions:
|
||||
return "No recent Copilot session summaries are available."
|
||||
|
||||
blocks: list[str] = []
|
||||
for session in sessions:
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
continue
|
||||
|
||||
title_suffix = f" ({session.title})" if session.title else ""
|
||||
lines = [
|
||||
f"### {_format_start_type_label(session.start_type)} session updated {session.updated_at.isoformat()}{title_suffix}",
|
||||
f"Summary: {_truncate_prompt_text(report.thoughts, AUTOPILOT_MESSAGE_CHAR_LIMIT)}",
|
||||
]
|
||||
if report.email_title:
|
||||
lines.append(
|
||||
"Email Title: "
|
||||
+ _truncate_prompt_text(
|
||||
report.email_title, AUTOPILOT_MESSAGE_CHAR_LIMIT
|
||||
)
|
||||
)
|
||||
blocks.append("\n".join(lines))
|
||||
|
||||
return (
|
||||
"\n\n".join(blocks)
|
||||
if blocks
|
||||
else "No recent Copilot session summaries are available."
|
||||
)
|
||||
|
||||
|
||||
async def _build_autopilot_system_prompt(
|
||||
user: Any,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
timezone_name: str,
|
||||
target_local_date: date | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> str:
|
||||
understanding = await understanding_db().get_business_understanding(user.id)
|
||||
business_understanding = (
|
||||
format_understanding_for_prompt(understanding)
|
||||
if understanding
|
||||
else "No saved business understanding yet."
|
||||
)
|
||||
recent_copilot_emails = await _get_recent_sent_email_context(user.id)
|
||||
recent_session_summaries = await _get_recent_session_summary_context(user.id)
|
||||
recent_manual_sessions = "Not applicable for this prompt type."
|
||||
beta_application_context = "No beta application context available."
|
||||
|
||||
users_information_sections = [
|
||||
"## Business Understanding\n" + business_understanding
|
||||
]
|
||||
users_information_sections.append(
|
||||
"## Recent Copilot Emails Sent To User\n" + recent_copilot_emails
|
||||
)
|
||||
users_information_sections.append(
|
||||
"## Recent Copilot Session Summaries\n" + recent_session_summaries
|
||||
)
|
||||
users_information = "\n\n".join(users_information_sections)
|
||||
|
||||
if (
|
||||
start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY
|
||||
and target_local_date is not None
|
||||
):
|
||||
recent_manual_sessions = await _get_recent_manual_session_context(
|
||||
user.id,
|
||||
since_utc=_get_previous_local_midnight_utc(
|
||||
target_local_date,
|
||||
timezone_name,
|
||||
),
|
||||
)
|
||||
|
||||
tally_understanding = _get_invited_user_tally_understanding(invited_user)
|
||||
if tally_understanding is not None:
|
||||
beta_application_context = json.dumps(tally_understanding, ensure_ascii=False)
|
||||
|
||||
return await _get_system_prompt_template(
|
||||
users_information,
|
||||
prompt_name=_get_autopilot_prompt_name(start_type),
|
||||
fallback_prompt=_get_autopilot_fallback_prompt(start_type),
|
||||
template_vars={
|
||||
"users_information": users_information,
|
||||
"business_understanding": business_understanding,
|
||||
"recent_copilot_emails": recent_copilot_emails,
|
||||
"recent_session_summaries": recent_session_summaries,
|
||||
"recent_manual_sessions": recent_manual_sessions,
|
||||
"beta_application_context": beta_application_context,
|
||||
},
|
||||
)
|
||||
1088
autogpt_platform/backend/backend/copilot/autopilot_test.py
Normal file
1088
autogpt_platform/backend/backend/copilot/autopilot_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -38,8 +38,8 @@ from backend.copilot.response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_resolve_system_prompt,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
@@ -160,7 +160,7 @@ async def stream_chat_completion_baseline(
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
if is_user_message and session.is_manual and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
@@ -177,16 +177,20 @@ async def stream_chat_completion_baseline(
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
base_system_prompt, _ = await _resolve_system_prompt(
|
||||
session,
|
||||
user_id,
|
||||
has_conversation_history=False,
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
base_system_prompt, _ = await _resolve_system_prompt(
|
||||
session,
|
||||
user_id=None,
|
||||
has_conversation_history=True,
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
system_prompt = base_system_prompt + get_baseline_supplement(session)
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
@@ -199,7 +203,7 @@ async def stream_chat_completion_baseline(
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
tools = get_available_tools(session)
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
|
||||
@@ -65,6 +65,18 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_autopilot_nightly_prompt_name: str = Field(
|
||||
default="CoPilot Nightly",
|
||||
description="Langfuse prompt name for nightly Autopilot sessions",
|
||||
)
|
||||
langfuse_autopilot_callback_prompt_name: str = Field(
|
||||
default="CoPilot Callback",
|
||||
description="Langfuse prompt name for callback Autopilot sessions",
|
||||
)
|
||||
langfuse_autopilot_invite_cta_prompt_name: str = Field(
|
||||
default="CoPilot Beta Invite CTA",
|
||||
description="Langfuse prompt name for beta invite CTA Autopilot sessions",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
|
||||
@@ -11,8 +11,6 @@ from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
@@ -84,17 +82,6 @@ def resolve_sandbox_path(path: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
|
||||
@@ -8,19 +8,48 @@ from typing import Any
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.models import ChatSessionCallbackToken as PrismaChatSessionCallbackToken
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
from .session_types import ChatSessionStartType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class ChatSessionCallbackTokenInfo(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
source_session_id: str | None = None
|
||||
callback_session_message: str
|
||||
expires_at: datetime
|
||||
consumed_at: datetime | None = None
|
||||
consumed_session_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(
|
||||
cls,
|
||||
token: PrismaChatSessionCallbackToken,
|
||||
) -> "ChatSessionCallbackTokenInfo":
|
||||
return cls(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
source_session_id=token.sourceSessionId,
|
||||
callback_session_message=token.callbackSessionMessage,
|
||||
expires_at=token.expiresAt,
|
||||
consumed_at=token.consumedAt,
|
||||
consumed_session_id=token.consumedSessionId,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
@@ -32,9 +61,103 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def has_recent_manual_message(user_id: str, since: datetime) -> bool:
|
||||
message = await PrismaChatMessage.prisma().find_first(
|
||||
where={
|
||||
"role": "user",
|
||||
"createdAt": {"gte": since},
|
||||
"Session": {
|
||||
"is": {
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
return message is not None
|
||||
|
||||
|
||||
async def has_session_since(user_id: str, since: datetime) -> bool:
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "createdAt": {"gte": since}}
|
||||
)
|
||||
return session is not None
|
||||
|
||||
|
||||
async def session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "executionTag": execution_tag}
|
||||
)
|
||||
return session is not None
|
||||
|
||||
|
||||
async def get_manual_chat_sessions_since(
|
||||
user_id: str,
|
||||
since_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
"updatedAt": {"gte": since_utc},
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_chat_messages_since(
|
||||
session_id: str,
|
||||
since_utc: datetime,
|
||||
) -> list[ChatMessage]:
|
||||
messages = await PrismaChatMessage.prisma().find_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"createdAt": {"gte": since_utc},
|
||||
},
|
||||
order={"sequence": "asc"},
|
||||
)
|
||||
return [ChatMessage.from_db(message) for message in messages]
|
||||
|
||||
|
||||
def _build_chat_message_create_input(
|
||||
*,
|
||||
session_id: str,
|
||||
sequence: int,
|
||||
now: datetime,
|
||||
msg: dict[str, Any],
|
||||
) -> ChatMessageCreateInput:
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": sequence,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: dict[str, Any] | None = None,
|
||||
) -> ChatSessionInfo:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
@@ -43,6 +166,9 @@ async def create_chat_session(
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
startType=start_type.value,
|
||||
executionTag=execution_tag,
|
||||
sessionConfig=SafeJson(session_config or {}),
|
||||
)
|
||||
prisma_session = await PrismaChatSession.prisma().create(data=data)
|
||||
return ChatSessionInfo.from_db(prisma_session)
|
||||
@@ -56,9 +182,19 @@ async def update_chat_session(
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
start_type: ChatSessionStartType | None = None,
|
||||
execution_tag: str | None | object = _UNSET,
|
||||
session_config: dict[str, Any] | None = None,
|
||||
completion_report: dict[str, Any] | None | object = _UNSET,
|
||||
completion_report_repair_count: int | None = None,
|
||||
completion_report_repair_queued_at: datetime | None | object = _UNSET,
|
||||
completed_at: datetime | None | object = _UNSET,
|
||||
notification_email_sent_at: datetime | None | object = _UNSET,
|
||||
notification_email_skipped_at: datetime | None | object = _UNSET,
|
||||
) -> ChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
should_clear_completion_report = completion_report is None
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
@@ -72,12 +208,41 @@ async def update_chat_session(
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
if start_type is not None:
|
||||
data["startType"] = start_type.value
|
||||
if execution_tag is not _UNSET:
|
||||
data["executionTag"] = execution_tag
|
||||
if session_config is not None:
|
||||
data["sessionConfig"] = SafeJson(session_config)
|
||||
if completion_report is not _UNSET and completion_report is not None:
|
||||
data["completionReport"] = SafeJson(completion_report)
|
||||
if completion_report_repair_count is not None:
|
||||
data["completionReportRepairCount"] = completion_report_repair_count
|
||||
if completion_report_repair_queued_at is not _UNSET:
|
||||
data["completionReportRepairQueuedAt"] = completion_report_repair_queued_at
|
||||
if completed_at is not _UNSET:
|
||||
data["completedAt"] = completed_at
|
||||
if notification_email_sent_at is not _UNSET:
|
||||
data["notificationEmailSentAt"] = notification_email_sent_at
|
||||
if notification_email_skipped_at is not _UNSET:
|
||||
data["notificationEmailSkippedAt"] = notification_email_skipped_at
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
)
|
||||
|
||||
if should_clear_completion_report:
|
||||
await db.execute_raw_with_schema(
|
||||
'UPDATE {schema_prefix}"ChatSession" SET "completionReport" = NULL WHERE "id" = $1',
|
||||
session_id,
|
||||
)
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
)
|
||||
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
@@ -187,37 +352,15 @@ async def add_chat_messages_batch(
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with db.transaction() as tx:
|
||||
# Build all message data
|
||||
messages_data = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
# Note: create_many doesn't support nested creates, use sessionId directly
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields — sanitize to strip
|
||||
# PostgreSQL-incompatible control characters.
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
messages_data.append(data)
|
||||
messages_data = [
|
||||
_build_chat_message_create_input(
|
||||
session_id=session_id,
|
||||
sequence=start_sequence + i,
|
||||
now=now,
|
||||
msg=msg,
|
||||
)
|
||||
for i, msg in enumerate(messages)
|
||||
]
|
||||
|
||||
# Run create_many and session update in parallel within transaction
|
||||
# Both use the same timestamp for consistency
|
||||
@@ -256,10 +399,14 @@ async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
with_auto: bool = False,
|
||||
) -> list[ChatSessionInfo]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
prisma_sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
where={
|
||||
"userId": user_id,
|
||||
**({} if with_auto else {"startType": ChatSessionStartType.MANUAL.value}),
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
@@ -267,9 +414,88 @@ async def get_user_chat_sessions(
|
||||
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
async def get_pending_notification_chat_sessions(
|
||||
limit: int = 200,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": None,
|
||||
"notificationEmailSkippedAt": None,
|
||||
},
|
||||
order={"updatedAt": "asc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_pending_notification_chat_sessions_for_user(
|
||||
user_id: str,
|
||||
limit: int = 200,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": None,
|
||||
"notificationEmailSkippedAt": None,
|
||||
},
|
||||
order={"updatedAt": "asc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_recent_sent_email_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": {"not": None},
|
||||
},
|
||||
order={"notificationEmailSentAt": "desc"},
|
||||
take=max(limit * 3, limit),
|
||||
)
|
||||
return [
|
||||
session_info
|
||||
for session_info in (ChatSessionInfo.from_db(session) for session in sessions)
|
||||
if session_info.notification_email_sent_at and session_info.completion_report
|
||||
][:limit]
|
||||
|
||||
|
||||
async def get_recent_completion_report_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=max(limit * 5, 10),
|
||||
)
|
||||
return [
|
||||
session_info
|
||||
for session_info in (ChatSessionInfo.from_db(session) for session in sessions)
|
||||
if session_info.completion_report is not None
|
||||
][:limit]
|
||||
|
||||
|
||||
async def get_user_session_count(
|
||||
user_id: str,
|
||||
with_auto: bool = False,
|
||||
) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
return await PrismaChatSession.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
**({} if with_auto else {"startType": ChatSessionStartType.MANUAL.value}),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
@@ -359,3 +585,42 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def create_chat_session_callback_token(
|
||||
user_id: str,
|
||||
source_session_id: str,
|
||||
callback_session_message: str,
|
||||
expires_at: datetime,
|
||||
) -> ChatSessionCallbackTokenInfo:
|
||||
token = await PrismaChatSessionCallbackToken.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"sourceSessionId": source_session_id,
|
||||
"callbackSessionMessage": callback_session_message,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
return ChatSessionCallbackTokenInfo.from_db(token)
|
||||
|
||||
|
||||
async def get_chat_session_callback_token(
|
||||
token_id: str,
|
||||
) -> ChatSessionCallbackTokenInfo | None:
|
||||
token = await PrismaChatSessionCallbackToken.prisma().find_unique(
|
||||
where={"id": token_id}
|
||||
)
|
||||
return ChatSessionCallbackTokenInfo.from_db(token) if token else None
|
||||
|
||||
|
||||
async def mark_chat_session_callback_token_consumed(
|
||||
token_id: str,
|
||||
consumed_session_id: str,
|
||||
) -> None:
|
||||
await PrismaChatSessionCallbackToken.prisma().update(
|
||||
where={"id": token_id},
|
||||
data={
|
||||
"consumedAt": datetime.now(UTC),
|
||||
"consumedSessionId": consumed_session_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps provider slug → env var names to inject when the provider is connected.
|
||||
# Add new providers here when adding integration support.
|
||||
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
|
||||
# must be updated when adding a new provider.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"falling back to potentially stale token",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "stale-oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -21,7 +21,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
@@ -29,6 +29,11 @@ from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from .config import ChatConfig
|
||||
from .session_types import (
|
||||
ChatSessionConfig,
|
||||
ChatSessionStartType,
|
||||
StoredCompletionReport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -80,11 +85,20 @@ class ChatSessionInfo(BaseModel):
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
credentials: dict[str, dict] = Field(default_factory=dict)
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
successful_agent_runs: dict[str, int] = Field(default_factory=dict)
|
||||
successful_agent_schedules: dict[str, int] = Field(default_factory=dict)
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL
|
||||
execution_tag: str | None = None
|
||||
session_config: ChatSessionConfig = Field(default_factory=ChatSessionConfig)
|
||||
completion_report: StoredCompletionReport | None = None
|
||||
completion_report_repair_count: int = 0
|
||||
completion_report_repair_queued_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
notification_email_sent_at: datetime | None = None
|
||||
notification_email_skipped_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
|
||||
@@ -97,6 +111,8 @@ class ChatSessionInfo(BaseModel):
|
||||
successful_agent_schedules = _parse_json_field(
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
session_config = _parse_json_field(prisma_session.sessionConfig, default={})
|
||||
completion_report = _parse_json_field(prisma_session.completionReport)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
@@ -110,6 +126,20 @@ class ChatSessionInfo(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
parsed_session_config = ChatSessionConfig.model_validate(session_config or {})
|
||||
parsed_completion_report = None
|
||||
if isinstance(completion_report, dict):
|
||||
try:
|
||||
parsed_completion_report = StoredCompletionReport.model_validate(
|
||||
completion_report
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Invalid completionReport payload on session %s",
|
||||
prisma_session.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return cls(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
@@ -120,6 +150,15 @@ class ChatSessionInfo(BaseModel):
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
start_type=ChatSessionStartType(str(prisma_session.startType)),
|
||||
execution_tag=prisma_session.executionTag,
|
||||
session_config=parsed_session_config,
|
||||
completion_report=parsed_completion_report,
|
||||
completion_report_repair_count=prisma_session.completionReportRepairCount,
|
||||
completion_report_repair_queued_at=prisma_session.completionReportRepairQueuedAt,
|
||||
completed_at=prisma_session.completedAt,
|
||||
notification_email_sent_at=prisma_session.notificationEmailSentAt,
|
||||
notification_email_skipped_at=prisma_session.notificationEmailSkippedAt,
|
||||
)
|
||||
|
||||
|
||||
@@ -127,7 +166,13 @@ class ChatSession(ChatSessionInfo):
|
||||
messages: list[ChatMessage]
|
||||
|
||||
@classmethod
|
||||
def new(cls, user_id: str) -> Self:
|
||||
def new(
|
||||
cls,
|
||||
user_id: str,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: ChatSessionConfig | None = None,
|
||||
) -> Self:
|
||||
return cls(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -137,6 +182,9 @@ class ChatSession(ChatSessionInfo):
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
start_type=start_type,
|
||||
execution_tag=execution_tag,
|
||||
session_config=session_config or ChatSessionConfig(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -152,6 +200,16 @@ class ChatSession(ChatSessionInfo):
|
||||
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
|
||||
)
|
||||
|
||||
@property
|
||||
def is_manual(self) -> bool:
|
||||
return self.start_type == ChatSessionStartType.MANUAL
|
||||
|
||||
def allows_tool(self, tool_name: str) -> bool:
|
||||
return self.session_config.allows_tool(tool_name)
|
||||
|
||||
def disables_tool(self, tool_name: str) -> bool:
|
||||
return self.session_config.disables_tool(tool_name)
|
||||
|
||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||
"""Attach a tool_call to the current turn's assistant message.
|
||||
|
||||
@@ -524,6 +582,9 @@ async def _save_session_to_db(
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
session_config=session.session_config.model_dump(mode="json"),
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
@@ -539,6 +600,19 @@ async def _save_session_to_db(
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
session_config=session.session_config.model_dump(mode="json"),
|
||||
completion_report=(
|
||||
session.completion_report.model_dump(mode="json")
|
||||
if session.completion_report
|
||||
else None
|
||||
),
|
||||
completion_report_repair_count=session.completion_report_repair_count,
|
||||
completion_report_repair_queued_at=session.completion_report_repair_queued_at,
|
||||
completed_at=session.completed_at,
|
||||
notification_email_sent_at=session.notification_email_sent_at,
|
||||
notification_email_skipped_at=session.notification_email_skipped_at,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
@@ -601,7 +675,13 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
async def create_chat_session(
|
||||
user_id: str,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: ChatSessionConfig | None = None,
|
||||
initial_messages: list[ChatMessage] | None = None,
|
||||
) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
Raises:
|
||||
@@ -609,14 +689,30 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
callers never receive a non-persisted session that only exists
|
||||
in cache (which would be lost when the cache expires).
|
||||
"""
|
||||
session = ChatSession.new(user_id)
|
||||
session = ChatSession.new(
|
||||
user_id,
|
||||
start_type=start_type,
|
||||
execution_tag=execution_tag,
|
||||
session_config=session_config,
|
||||
)
|
||||
if initial_messages:
|
||||
session.messages.extend(initial_messages)
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db().create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
session_config=session.session_config.model_dump(mode="json"),
|
||||
)
|
||||
if session.messages:
|
||||
await _save_session_to_db(
|
||||
session,
|
||||
0,
|
||||
skip_existence_check=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session {session.session_id} in database: {e}")
|
||||
raise DatabaseError(
|
||||
@@ -636,6 +732,7 @@ async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
with_auto: bool = False,
|
||||
) -> tuple[list[ChatSessionInfo], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
@@ -644,8 +741,16 @@ async def get_user_sessions(
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
db = chat_db()
|
||||
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await db.get_user_session_count(user_id)
|
||||
sessions = await db.get_user_chat_sessions(
|
||||
user_id,
|
||||
limit,
|
||||
offset,
|
||||
with_auto=with_auto,
|
||||
)
|
||||
total_count = await db.get_user_session_count(
|
||||
user_id,
|
||||
with_auto=with_auto,
|
||||
)
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from .model import (
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .session_types import ChatSessionConfig, ChatSessionStartType
|
||||
|
||||
messages = [
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
@@ -46,7 +47,15 @@ messages = [
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s = ChatSession.new(
|
||||
user_id="abc123",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag="autopilot-nightly:2026-03-13",
|
||||
session_config=ChatSessionConfig(
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=["edit_agent"],
|
||||
),
|
||||
)
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
|
||||
@@ -6,7 +6,7 @@ handling the distinction between:
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
@@ -52,68 +52,17 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
**Type coercion**: The platform automatically coerces expanded string values
|
||||
to match the block's expected input types. For example, if a block expects
|
||||
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
|
||||
an @@agptfile: expansion), the string will be parsed into the correct type.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
@@ -124,7 +73,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -139,7 +87,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -173,16 +120,12 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -200,11 +143,7 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -219,11 +158,10 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
def _generate_tool_documentation(session=None) -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
@@ -239,11 +177,7 @@ def _generate_tool_documentation() -> str:
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
for name, tool in sorted(iter_available_tools(session), key=lambda item: item[0]):
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
@@ -271,7 +205,7 @@ def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
def get_baseline_supplement(session=None) -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
@@ -281,5 +215,5 @@ def get_baseline_supplement() -> str:
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
tool_docs = _generate_tool_documentation(session)
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
|
||||
@@ -3,45 +3,12 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -41,20 +41,12 @@ from typing import Any
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
@@ -82,8 +74,6 @@ _FILE_REF_RE = re.compile(
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,11 +83,6 @@ class FileRef:
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
@@ -119,6 +104,17 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
@@ -134,47 +130,27 @@ async def read_file_bytes(
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
@@ -186,33 +162,9 @@ async def read_file_bytes(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
@@ -226,13 +178,15 @@ async def resolve_file_ref(
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -278,9 +232,6 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -301,31 +252,13 @@ async def expand_file_refs_in_string(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
@@ -334,382 +267,15 @@ async def expand_file_refs_in_args(
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
|
||||
@@ -175,199 +175,6 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -412,7 +219,7 @@ async def test_read_file_handler_workspace_uri():
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
@@ -469,7 +276,7 @@ async def test_read_file_bytes_workspace_virtual_path():
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import openai
|
||||
from claude_agent_sdk import (
|
||||
@@ -29,7 +29,6 @@ from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -57,12 +56,13 @@ from ..response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
_resolve_system_prompt,
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
@@ -88,6 +88,10 @@ logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class _ClaudeSDKTransport(Protocol):
|
||||
async def write(self, data: str) -> None: ...
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
"""Configure OTEL tracing for the Claude Agent SDK → Langfuse.
|
||||
|
||||
@@ -137,6 +141,16 @@ def _setup_langfuse_otel() -> None:
|
||||
_setup_langfuse_otel()
|
||||
|
||||
|
||||
async def _write_multimodal_query(
|
||||
client: ClaudeSDKClient,
|
||||
user_message: dict[str, Any],
|
||||
) -> None:
|
||||
transport = cast(_ClaudeSDKTransport | None, getattr(client, "_transport", None))
|
||||
if transport is None:
|
||||
raise RuntimeError("Claude SDK transport is unavailable for multimodal input")
|
||||
await transport.write(json.dumps(user_message) + "\n")
|
||||
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
@@ -565,7 +579,7 @@ async def _prepare_file_attachments(
|
||||
return empty
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create workspace manager for file attachments",
|
||||
@@ -690,7 +704,7 @@ async def stream_chat_completion_sdk(
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
if is_user_message and session.is_manual and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
@@ -769,7 +783,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
return None
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
@@ -783,9 +797,7 @@ async def stream_chat_completion_sdk(
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
return sandbox
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
@@ -807,7 +819,11 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
e2b_sandbox, (base_system_prompt, _), dl = await asyncio.gather(
|
||||
_setup_e2b(),
|
||||
_build_system_prompt(user_id, has_conversation_history=has_history),
|
||||
_resolve_system_prompt(
|
||||
session,
|
||||
user_id,
|
||||
has_conversation_history=has_history,
|
||||
),
|
||||
_fetch_transcript(),
|
||||
)
|
||||
|
||||
@@ -864,7 +880,7 @@ async def stream_chat_completion_sdk(
|
||||
"Claude Code CLI subscription (requires `claude login`)."
|
||||
)
|
||||
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
mcp_server = create_copilot_mcp_server(session, use_e2b=use_e2b)
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
@@ -878,7 +894,7 @@ async def stream_chat_completion_sdk(
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
allowed = get_copilot_tool_names(session, use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
def _on_stderr(line: str) -> None:
|
||||
@@ -979,10 +995,7 @@ async def stream_chat_completion_sdk(
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": session_id,
|
||||
}
|
||||
assert client._transport is not None # noqa: SLF001
|
||||
await client._transport.write( # noqa: SLF001
|
||||
json.dumps(user_msg) + "\n"
|
||||
)
|
||||
await _write_multimodal_query(client, user_msg)
|
||||
# Capture user message in transcript (multimodal)
|
||||
transcript_builder.append_user(content=content_blocks)
|
||||
else:
|
||||
|
||||
@@ -20,7 +20,7 @@ class _FakeFileInfo:
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
@@ -205,6 +205,29 @@ class TestPromptSupplement:
|
||||
):
|
||||
assert "`browser_navigate`" in docs
|
||||
|
||||
def test_baseline_supplement_respects_session_disabled_tools(self):
|
||||
"""Session-specific docs should hide disabled tools and include added session tools."""
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.session_types import (
|
||||
ChatSessionConfig,
|
||||
ChatSessionStartType,
|
||||
)
|
||||
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
session_config=ChatSessionConfig(
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=["edit_agent"],
|
||||
),
|
||||
)
|
||||
|
||||
docs = get_baseline_supplement(session)
|
||||
|
||||
assert "`completion_report`" in docs
|
||||
assert "`edit_agent`" not in docs
|
||||
|
||||
def test_baseline_supplement_includes_workflows(self):
|
||||
"""Baseline supplement should include workflow guidance in tool descriptions."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
@@ -219,15 +242,13 @@ class TestPromptSupplement:
|
||||
def test_baseline_supplement_completeness(self):
|
||||
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Verify each available registered tool is documented
|
||||
# (matches _generate_tool_documentation which filters by is_available)
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
# (matches _generate_tool_documentation which filters with iter_available_tools)
|
||||
for tool_name, _ in iter_available_tools():
|
||||
assert (
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
@@ -277,14 +298,12 @@ class TestPromptSupplement:
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Count occurrences of each available tool in the entire supplement
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
for tool_name, _ in iter_available_tools():
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
@@ -32,7 +32,7 @@ from backend.copilot.sdk.file_ref import (
|
||||
expand_file_refs_in_args,
|
||||
read_file_bytes,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
@@ -338,7 +338,11 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
def create_copilot_mcp_server(
|
||||
session: ChatSession,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
):
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
When *use_e2b* is True, five additional MCP file tools are registered
|
||||
@@ -347,7 +351,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
|
||||
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
@@ -361,9 +365,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
@@ -389,14 +391,13 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
for tool_name, base_tool in iter_available_tools(session):
|
||||
handler = create_tool_handler(base_tool)
|
||||
schema = _build_input_schema(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
schema,
|
||||
)(_truncating(handler, tool_name, input_schema=schema))
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
@@ -478,25 +479,30 @@ DANGEROUS_PATTERNS = [
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# Static tool name list for the non-E2B case (backward compatibility).
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
|
||||
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
|
||||
def get_copilot_tool_names(
|
||||
session: ChatSession,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
) -> list[str]:
|
||||
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
|
||||
equivalents that route to the E2B sandbox.
|
||||
"""
|
||||
tool_names = [
|
||||
f"{MCP_TOOL_PREFIX}{name}" for name, _ in iter_available_tools(session)
|
||||
]
|
||||
|
||||
if not use_e2b:
|
||||
return list(COPILOT_TOOL_NAMES)
|
||||
return [
|
||||
*tool_names,
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
return [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
*tool_names,
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
|
||||
*_SDK_BUILTIN_ALWAYS,
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import ChatSessionConfig, ChatSessionStartType
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_copilot_tool_names,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
@@ -168,3 +171,20 @@ class TestTruncationAndStashIntegration:
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
|
||||
|
||||
class TestSessionToolFiltering:
|
||||
def test_disabled_tools_are_removed_from_sdk_allowed_tools(self):
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
session_config=ChatSessionConfig(
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=["edit_agent"],
|
||||
),
|
||||
)
|
||||
|
||||
tool_names = get_copilot_tool_names(session)
|
||||
|
||||
assert "mcp__copilot__completion_report" in tool_names
|
||||
assert "mcp__copilot__edit_agent" not in tool_names
|
||||
|
||||
@@ -22,7 +22,7 @@ from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSessionInfo, get_chat_session, upsert_chat_session
|
||||
from .model import ChatSession, ChatSessionInfo, get_chat_session, upsert_chat_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,7 +64,13 @@ def _is_langfuse_configured() -> bool:
|
||||
)
|
||||
|
||||
|
||||
async def _get_system_prompt_template(context: str) -> str:
|
||||
async def _get_system_prompt_template(
|
||||
context: str,
|
||||
*,
|
||||
prompt_name: str | None = None,
|
||||
fallback_prompt: str | None = None,
|
||||
template_vars: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Get the system prompt, trying Langfuse first with fallback to default.
|
||||
|
||||
Args:
|
||||
@@ -73,6 +79,11 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
Returns:
|
||||
The compiled system prompt string.
|
||||
"""
|
||||
resolved_prompt_name = prompt_name or config.langfuse_prompt_name
|
||||
resolved_template_vars = {
|
||||
"users_information": context,
|
||||
**(template_vars or {}),
|
||||
}
|
||||
if _is_langfuse_configured():
|
||||
try:
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
@@ -85,16 +96,16 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
)
|
||||
prompt = await asyncio.to_thread(
|
||||
langfuse.get_prompt,
|
||||
config.langfuse_prompt_name,
|
||||
resolved_prompt_name,
|
||||
label=label,
|
||||
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
||||
)
|
||||
return prompt.compile(users_information=context)
|
||||
return prompt.compile(**resolved_template_vars)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
|
||||
|
||||
# Fallback to default prompt
|
||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||
return (fallback_prompt or DEFAULT_SYSTEM_PROMPT).format(**resolved_template_vars)
|
||||
|
||||
|
||||
async def _build_system_prompt(
|
||||
@@ -131,6 +142,21 @@ async def _build_system_prompt(
|
||||
return compiled, understanding
|
||||
|
||||
|
||||
async def _resolve_system_prompt(
|
||||
session: ChatSession,
|
||||
user_id: str | None,
|
||||
*,
|
||||
has_conversation_history: bool = False,
|
||||
) -> tuple[str, Any]:
|
||||
override = session.session_config.system_prompt_override
|
||||
if override:
|
||||
return override, None
|
||||
return await _build_system_prompt(
|
||||
user_id,
|
||||
has_conversation_history=has_conversation_history,
|
||||
)
|
||||
|
||||
|
||||
async def _generate_session_title(
|
||||
message: str,
|
||||
user_id: str | None = None,
|
||||
|
||||
60
autogpt_platform/backend/backend/copilot/session_types.py
Normal file
60
autogpt_platform/backend/backend/copilot/session_types.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class ChatSessionStartType(str, Enum):
|
||||
MANUAL = "MANUAL"
|
||||
AUTOPILOT_NIGHTLY = "AUTOPILOT_NIGHTLY"
|
||||
AUTOPILOT_CALLBACK = "AUTOPILOT_CALLBACK"
|
||||
AUTOPILOT_INVITE_CTA = "AUTOPILOT_INVITE_CTA"
|
||||
|
||||
|
||||
class ChatSessionConfig(BaseModel):
|
||||
system_prompt_override: str | None = None
|
||||
initial_user_message: str | None = None
|
||||
initial_assistant_message: str | None = None
|
||||
extra_tools: list[str] = Field(default_factory=list)
|
||||
disabled_tools: list[str] = Field(default_factory=list)
|
||||
|
||||
def allows_tool(self, tool_name: str) -> bool:
|
||||
return tool_name in self.extra_tools
|
||||
|
||||
def disables_tool(self, tool_name: str) -> bool:
|
||||
return tool_name in self.disabled_tools
|
||||
|
||||
|
||||
class CompletionReportInput(BaseModel):
|
||||
thoughts: str
|
||||
should_notify_user: bool
|
||||
email_title: str | None = None
|
||||
email_body: str | None = None
|
||||
callback_session_message: str | None = None
|
||||
approval_summary: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_notification_fields(self) -> "CompletionReportInput":
|
||||
if self.should_notify_user:
|
||||
required_fields = {
|
||||
"email_title": self.email_title,
|
||||
"email_body": self.email_body,
|
||||
"callback_session_message": self.callback_session_message,
|
||||
}
|
||||
missing = [
|
||||
field_name for field_name, value in required_fields.items() if not value
|
||||
]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Missing required notification fields: " + ", ".join(missing)
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class StoredCompletionReport(CompletionReportInput):
|
||||
has_pending_approvals: bool
|
||||
pending_approval_count: int
|
||||
pending_approval_graph_exec_id: str | None = None
|
||||
saved_at: datetime
|
||||
@@ -17,11 +17,12 @@ Subscribers:
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Awaitable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
@@ -55,6 +56,12 @@ _listener_sessions: dict[int, tuple[str, asyncio.Task]] = {}
|
||||
# Timeout for putting chunks into subscriber queues (seconds)
|
||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||
QUEUE_PUT_TIMEOUT = 5.0
|
||||
SESSION_LOOKUP_RETRY_SECONDS = 0.05
|
||||
STREAM_REPLAY_COUNT = 1000
|
||||
STREAM_XREAD_BLOCK_MS = 5000
|
||||
STREAM_XREAD_COUNT = 100
|
||||
STALE_SESSION_BUFFER_SECONDS = 300
|
||||
UNSUBSCRIBE_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||
# Returns 1 if status was updated, 0 if already completed/failed
|
||||
@@ -68,19 +75,24 @@ return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveSession:
|
||||
SessionStatus = Literal["running", "completed", "failed"]
|
||||
RedisHash = dict[str, str]
|
||||
RedisStreamMessages = list[tuple[str, list[tuple[str, RedisHash]]]]
|
||||
|
||||
|
||||
class ActiveSession(BaseModel):
|
||||
"""Represents an active streaming session (metadata only, no in-memory queues)."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
turn_id: str = ""
|
||||
blocking: bool = False # If True, HTTP request is waiting for completion
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
status: SessionStatus = "running"
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
def _get_session_meta_key(session_id: str) -> str:
|
||||
@@ -93,7 +105,54 @@ def _get_turn_stream_key(turn_id: str) -> str:
|
||||
return f"{config.turn_stream_prefix}{turn_id}"
|
||||
|
||||
|
||||
def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSession:
|
||||
async def _redis_hset_mapping(redis: Any, key: str, mapping: RedisHash) -> int:
|
||||
return await cast(Awaitable[int], redis.hset(key, mapping=mapping))
|
||||
|
||||
|
||||
async def _redis_hgetall(redis: Any, key: str) -> RedisHash:
|
||||
return cast(
|
||||
RedisHash,
|
||||
await cast(Awaitable[dict[str, str]], redis.hgetall(key)),
|
||||
)
|
||||
|
||||
|
||||
async def _redis_hget(redis: Any, key: str, field: str) -> str | None:
|
||||
return cast(
|
||||
str | None,
|
||||
await cast(Awaitable[str | None], redis.hget(key, field)),
|
||||
)
|
||||
|
||||
|
||||
async def _redis_xread(
|
||||
redis: Any,
|
||||
streams: dict[str, str],
|
||||
*,
|
||||
count: int,
|
||||
block: int | None,
|
||||
) -> RedisStreamMessages:
|
||||
return cast(
|
||||
RedisStreamMessages,
|
||||
await cast(
|
||||
Awaitable[RedisStreamMessages],
|
||||
redis.xread(streams, count=count, block=block),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _redis_complete_session(
|
||||
redis: Any,
|
||||
meta_key: str,
|
||||
status: SessionStatus,
|
||||
) -> int:
|
||||
return int(
|
||||
await cast(
|
||||
Awaitable[int | str],
|
||||
redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _parse_session_meta(meta: RedisHash, session_id: str = "") -> ActiveSession:
|
||||
"""Parse a raw Redis hash into a typed ActiveSession.
|
||||
|
||||
Centralises the ``meta.get(...)`` boilerplate so callers don't repeat it.
|
||||
@@ -107,7 +166,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
turn_id=meta.get("turn_id", "") or session_id,
|
||||
blocking=meta.get("blocking") == "1",
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
status=cast(SessionStatus, meta.get("status", "running")),
|
||||
)
|
||||
|
||||
|
||||
@@ -170,7 +229,8 @@ async def create_session(
|
||||
# No need to delete old stream — each turn_id is a fresh UUID
|
||||
|
||||
hset_start = time.perf_counter()
|
||||
await redis.hset( # type: ignore[misc]
|
||||
await _redis_hset_mapping(
|
||||
redis,
|
||||
meta_key,
|
||||
mapping={
|
||||
"session_id": session_id,
|
||||
@@ -280,6 +340,108 @@ async def publish_chunk(
|
||||
return message_id
|
||||
|
||||
|
||||
def _decode_stream_chunk(msg_data: RedisHash) -> StreamBaseResponse | None:
|
||||
raw_data = msg_data.get("data")
|
||||
if raw_data is None:
|
||||
return None
|
||||
|
||||
chunk_data = orjson.loads(raw_data)
|
||||
return _reconstruct_chunk(chunk_data)
|
||||
|
||||
|
||||
async def _replay_messages(
|
||||
messages: RedisStreamMessages,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
*,
|
||||
last_message_id: str,
|
||||
) -> tuple[int, str]:
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id
|
||||
try:
|
||||
chunk = _decode_stream_chunk(msg_data)
|
||||
if chunk is None:
|
||||
continue
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to replay message: %s", exc)
|
||||
|
||||
return replayed_count, replay_last_id
|
||||
|
||||
|
||||
async def _deliver_message_to_queue(
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
chunk: StreamBaseResponse,
|
||||
*,
|
||||
last_delivered_id: str,
|
||||
log_meta: dict[str, Any],
|
||||
) -> bool:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||
"reason": "queue_full",
|
||||
}
|
||||
},
|
||||
)
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for session {session_id}, queue completely blocked"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_xread_timeout(
|
||||
redis: Any,
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> bool:
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
status = await _redis_hget(redis, meta_key, "status")
|
||||
if status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout delivering finish event for session {session_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamHeartbeat()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout delivering heartbeat for session {session_id}")
|
||||
return True
|
||||
|
||||
|
||||
async def subscribe_to_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
@@ -313,7 +475,7 @@ async def subscribe_to_session(
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||
@@ -328,8 +490,8 @@ async def subscribe_to_session(
|
||||
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
await asyncio.sleep(SESSION_LOOKUP_RETRY_SECONDS)
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
if not meta:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
@@ -374,7 +536,12 @@ async def subscribe_to_session(
|
||||
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
xread_start = time.perf_counter()
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000)
|
||||
messages = await _redis_xread(
|
||||
redis,
|
||||
{stream_key: last_message_id},
|
||||
block=None,
|
||||
count=STREAM_REPLAY_COUNT,
|
||||
)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}",
|
||||
@@ -387,22 +554,11 @@ async def subscribe_to_session(
|
||||
},
|
||||
)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
if "data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
replayed_count, replay_last_id = await _replay_messages(
|
||||
messages,
|
||||
subscriber_queue,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||
@@ -455,7 +611,7 @@ async def _stream_listener(
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
last_replayed_id: str,
|
||||
log_meta: dict | None = None,
|
||||
log_meta: dict[str, Any] | None = None,
|
||||
turn_id: str = "",
|
||||
) -> None:
|
||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||
@@ -499,8 +655,11 @@ async def _stream_listener(
|
||||
# Short timeout prevents frontend timeout (12s) while waiting for heartbeats (15s)
|
||||
xread_start = time.perf_counter()
|
||||
xread_count += 1
|
||||
messages = await redis.xread(
|
||||
{stream_key: current_id}, block=5000, count=100
|
||||
messages = await _redis_xread(
|
||||
redis,
|
||||
{stream_key: current_id},
|
||||
block=STREAM_XREAD_BLOCK_MS,
|
||||
count=STREAM_XREAD_COUNT,
|
||||
)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
|
||||
@@ -532,114 +691,66 @@ async def _stream_listener(
|
||||
)
|
||||
|
||||
if not messages:
|
||||
# Timeout - check if session is still running
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||
# Stop if session metadata is gone (TTL expired) or status is not "running"
|
||||
if status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering finish event for session {session_id}"
|
||||
)
|
||||
if not await _handle_xread_timeout(
|
||||
redis,
|
||||
session_id,
|
||||
subscriber_queue,
|
||||
):
|
||||
break
|
||||
# Session still running - send heartbeat to keep connection alive
|
||||
# This prevents frontend timeout (12s) during long-running operations
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamHeartbeat()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering heartbeat for session {session_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
|
||||
if "data" not in msg_data:
|
||||
continue
|
||||
|
||||
current_id = msg_id
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
# Update last delivered ID on successful delivery
|
||||
last_delivered_id = current_id
|
||||
messages_delivered += 1
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||
"reason": "queue_full",
|
||||
}
|
||||
},
|
||||
)
|
||||
# Send overflow error with recovery info
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
# Queue is completely stuck, nothing more we can do
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for session {session_id}, "
|
||||
"queue completely blocked"
|
||||
)
|
||||
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
chunk = _decode_stream_chunk(msg_data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error processing stream message: {e}",
|
||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||
)
|
||||
continue
|
||||
|
||||
if chunk is None:
|
||||
continue
|
||||
|
||||
delivered = await _deliver_message_to_queue(
|
||||
session_id,
|
||||
subscriber_queue,
|
||||
chunk,
|
||||
last_delivered_id=last_delivered_id,
|
||||
log_meta=log_meta,
|
||||
)
|
||||
if delivered:
|
||||
last_delivered_id = current_id
|
||||
messages_delivered += 1
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
@@ -712,16 +823,16 @@ async def mark_session_completed(
|
||||
Returns:
|
||||
True if session was newly marked completed, False if already completed/failed
|
||||
"""
|
||||
status: Literal["completed", "failed"] = "failed" if error_message else "completed"
|
||||
status: SessionStatus = "failed" if error_message else "completed"
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
|
||||
# Resolve turn_id for publishing to the correct stream
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
result = await _redis_complete_session(redis, meta_key, status)
|
||||
|
||||
if result == 0:
|
||||
logger.debug(f"Session {session_id} already completed/failed, skipping")
|
||||
@@ -774,6 +885,18 @@ async def mark_session_completed(
|
||||
f"for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.copilot.autopilot import handle_non_manual_session_completion
|
||||
|
||||
await handle_non_manual_session_completion(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process non-manual completion for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -788,7 +911,7 @@ async def get_session(session_id: str) -> ActiveSession | None:
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
|
||||
if not meta:
|
||||
return None
|
||||
@@ -815,7 +938,7 @@ async def get_session_with_expiry_info(
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
|
||||
if not meta:
|
||||
# Metadata expired — we can't resolve turn_id, so check using
|
||||
@@ -847,7 +970,7 @@ async def get_active_session(
|
||||
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
|
||||
if not meta:
|
||||
return None, "0-0"
|
||||
@@ -871,7 +994,9 @@ async def get_active_session(
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str)
|
||||
age_seconds = (datetime.now(timezone.utc) - created_at).total_seconds()
|
||||
stale_threshold = COPILOT_CONSUMER_TIMEOUT_SECONDS + 300 # + 5min buffer
|
||||
stale_threshold = (
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS + STALE_SESSION_BUFFER_SECONDS
|
||||
)
|
||||
if age_seconds > stale_threshold:
|
||||
logger.warning(
|
||||
f"[STALE_SESSION] Auto-completing stale session {session_id[:8]}... "
|
||||
@@ -946,7 +1071,11 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
}
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
||||
if not isinstance(chunk_type, str):
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
|
||||
chunk_class = type_to_class.get(chunk_type)
|
||||
|
||||
if chunk_class is None:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
@@ -1011,7 +1140,7 @@ async def unsubscribe_from_session(
|
||||
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(listener_task, timeout=5.0)
|
||||
await asyncio.wait_for(listener_task, timeout=UNSUBSCRIBE_TIMEOUT_SECONDS)
|
||||
except asyncio.CancelledError:
|
||||
# Expected - the task was successfully cancelled
|
||||
pass
|
||||
|
||||
@@ -12,7 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .completion_report import CompletionReportTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -51,10 +51,12 @@ if TYPE_CHECKING:
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SESSION_SCOPED_TOOL_NAMES = {"completion_report"}
|
||||
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"completion_report": CompletionReportTool(),
|
||||
"create_agent": CreateAgentTool(),
|
||||
"customize_agent": CustomizeAgentTool(),
|
||||
"edit_agent": EditAgentTool(),
|
||||
@@ -85,7 +87,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
@@ -105,16 +106,38 @@ find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
|
||||
def get_available_tools() -> list[ChatCompletionToolParam]:
|
||||
def is_tool_enabled(tool_name: str, session: "ChatSession | None" = None) -> bool:
|
||||
if tool_name not in TOOL_REGISTRY:
|
||||
return False
|
||||
if session is not None and session.disables_tool(tool_name):
|
||||
return False
|
||||
if tool_name not in SESSION_SCOPED_TOOL_NAMES:
|
||||
return True
|
||||
if session is None:
|
||||
return False
|
||||
return session.allows_tool(tool_name)
|
||||
|
||||
|
||||
def iter_available_tools(
|
||||
session: "ChatSession | None" = None,
|
||||
) -> list[tuple[str, BaseTool]]:
|
||||
return [
|
||||
(tool_name, tool)
|
||||
for tool_name, tool in TOOL_REGISTRY.items()
|
||||
if tool.is_available and is_tool_enabled(tool_name, session)
|
||||
]
|
||||
|
||||
|
||||
def get_available_tools(
|
||||
session: "ChatSession | None" = None,
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
"""Return OpenAI tool schemas for tools available in the current environment.
|
||||
|
||||
Called per-request so that env-var or binary availability is evaluated
|
||||
fresh each time (e.g. browser_* tools are excluded when agent-browser
|
||||
CLI is not installed).
|
||||
"""
|
||||
return [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
|
||||
]
|
||||
return [tool.as_openai_tool() for _, tool in iter_available_tools(session)]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
@@ -130,6 +153,9 @@ async def execute_tool(
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
if not is_tool_enabled(tool_name, session):
|
||||
raise ValueError(f"Tool {tool_name} is not enabled for this session")
|
||||
|
||||
tool = get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
@@ -32,7 +32,6 @@ import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
@@ -44,6 +43,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +194,7 @@ async def _save_browser_state(
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
@@ -218,7 +218,7 @@ async def _restore_browser_state(
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_name)
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestHasLocalSession:
|
||||
# _save_browser_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
|
||||
|
||||
|
||||
def _make_mock_manager():
|
||||
|
||||
@@ -22,7 +22,6 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -97,9 +96,7 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -136,27 +133,14 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
envs.update(integration_env)
|
||||
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs=envs,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Tool for finalizing non-manual Copilot sessions."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.constants import COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import CompletionReportInput
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import CompletionReportSavedResponse, ErrorResponse, ToolResponseBase
|
||||
|
||||
|
||||
class CompletionReportTool(BaseTool):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "completion_report"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Finalize a non-manual session after you have finished the work. "
|
||||
"Use this exactly once at the end of the flow. "
|
||||
"Summarize what you did, state whether the user should be notified, "
|
||||
"and provide any email/callback content that should be used."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
schema = CompletionReportInput.model_json_schema()
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": schema.get("properties", {}),
|
||||
"required": [
|
||||
"thoughts",
|
||||
"should_notify_user",
|
||||
"email_title",
|
||||
"email_body",
|
||||
"callback_session_message",
|
||||
"approval_summary",
|
||||
],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
if session.is_manual:
|
||||
return ErrorResponse(
|
||||
message="completion_report is only available in non-manual sessions.",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
report = CompletionReportInput.model_validate(kwargs)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message="completion_report arguments are invalid.",
|
||||
error=str(exc),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
pending_approval_count = await review_db().count_pending_reviews_for_graph_exec(
|
||||
f"{COPILOT_SESSION_PREFIX}{session.session_id}",
|
||||
session.user_id,
|
||||
)
|
||||
|
||||
if pending_approval_count > 0 and not report.approval_summary:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"approval_summary is required because this session has pending approvals."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
return CompletionReportSavedResponse(
|
||||
message="Completion report recorded successfully.",
|
||||
session_id=session.session_id,
|
||||
has_pending_approvals=pending_approval_count > 0,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
@@ -0,0 +1,95 @@
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.copilot.tools.completion_report import CompletionReportTool
|
||||
from backend.copilot.tools.models import CompletionReportSavedResponse, ResponseType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_report_rejects_manual_sessions() -> None:
|
||||
tool = CompletionReportTool()
|
||||
session = ChatSession.new("user-1")
|
||||
|
||||
response = await tool._execute(
|
||||
user_id="user-1",
|
||||
session=session,
|
||||
thoughts="Wrapped up the session.",
|
||||
should_notify_user=False,
|
||||
email_title=None,
|
||||
email_body=None,
|
||||
callback_session_message=None,
|
||||
approval_summary=None,
|
||||
)
|
||||
|
||||
assert response.type == ResponseType.ERROR
|
||||
assert "non-manual sessions" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_report_requires_approval_summary_when_pending(
|
||||
mocker,
|
||||
) -> None:
|
||||
tool = CompletionReportTool()
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
)
|
||||
|
||||
review_store = Mock()
|
||||
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=2)
|
||||
mocker.patch(
|
||||
"backend.copilot.tools.completion_report.review_db",
|
||||
return_value=review_store,
|
||||
)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id="user-1",
|
||||
session=session,
|
||||
thoughts="Prepared a recommendation for the user.",
|
||||
should_notify_user=True,
|
||||
email_title="Your nightly update",
|
||||
email_body="I found something worth reviewing.",
|
||||
callback_session_message="Let's review the next step together.",
|
||||
approval_summary=None,
|
||||
)
|
||||
|
||||
assert response.type == ResponseType.ERROR
|
||||
assert "approval_summary is required" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_report_succeeds_without_pending_approvals(
|
||||
mocker,
|
||||
) -> None:
|
||||
tool = CompletionReportTool()
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
)
|
||||
|
||||
review_store = Mock()
|
||||
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=0)
|
||||
mocker.patch(
|
||||
"backend.copilot.tools.completion_report.review_db",
|
||||
return_value=review_store,
|
||||
)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id="user-1",
|
||||
session=session,
|
||||
thoughts="Reviewed the account and prepared a useful follow-up.",
|
||||
should_notify_user=True,
|
||||
email_title="Autopilot found something useful",
|
||||
email_body="I put together a recommendation for you.",
|
||||
callback_session_message="Open this chat and I will walk you through it.",
|
||||
approval_summary=None,
|
||||
)
|
||||
|
||||
assert response.type == ResponseType.COMPLETION_REPORT_SAVED
|
||||
response = cast(CompletionReportSavedResponse, response)
|
||||
assert response.has_pending_approvals is False
|
||||
assert response.pending_approval_count == 0
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _ProviderInfo(TypedDict):
|
||||
name: str
|
||||
types: list[str]
|
||||
# Default OAuth scopes requested when the agent doesn't specify any.
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
provider: str
|
||||
provider_name: str
|
||||
type: str
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Evaluated lazily (not at import time) to avoid triggering Secrets() during
|
||||
module import, which can fail in environments where secrets are not loaded.
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# Registry of known providers: name + supported credential types for the UI.
|
||||
# When adding a new provider, also add its env var names to
|
||||
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
|
||||
def _get_provider_info() -> dict[str, _ProviderInfo]:
|
||||
"""Build the provider registry, evaluating OAuth config lazily."""
|
||||
return {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"types": (
|
||||
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
|
||||
),
|
||||
# Default: repo scope covers clone/push/pull for public and private repos.
|
||||
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
|
||||
"scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(_get_provider_info().keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
provider_info = _get_provider_info()
|
||||
info = provider_info.get(provider)
|
||||
if not info:
|
||||
supported = ", ".join(f"'{p}'" for p in provider_info)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
provider_name: str = info["name"]
|
||||
supported_types: list[str] = info["types"]
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = info["scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {provider_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{provider_name} Credentials",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"type": supported_types[0],
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=provider_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -16,6 +16,7 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
NEED_LOGIN = "need_login"
|
||||
COMPLETION_REPORT_SAVED = "completion_report_saved"
|
||||
|
||||
# Agent discovery & execution
|
||||
AGENTS_FOUND = "agents_found"
|
||||
@@ -138,7 +139,7 @@ class NoResultsResponse(ToolResponseBase):
|
||||
"""Response when no agents found."""
|
||||
|
||||
type: ResponseType = ResponseType.NO_RESULTS
|
||||
suggestions: list[str] = []
|
||||
suggestions: list[str] = Field(default_factory=list)
|
||||
name: str = "no_results"
|
||||
|
||||
|
||||
@@ -170,8 +171,8 @@ class AgentDetails(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
in_library: bool = False
|
||||
inputs: dict[str, Any] = {}
|
||||
credentials: list[CredentialsMetaInput] = []
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
credentials: list[CredentialsMetaInput] = Field(default_factory=list)
|
||||
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
|
||||
trigger_info: dict[str, Any] | None = None
|
||||
|
||||
@@ -191,7 +192,7 @@ class UserReadiness(BaseModel):
|
||||
"""User readiness status."""
|
||||
|
||||
has_all_credentials: bool = False
|
||||
missing_credentials: dict[str, Any] = {}
|
||||
missing_credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
ready_to_run: bool = False
|
||||
|
||||
|
||||
@@ -248,6 +249,14 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CompletionReportSavedResponse(ToolResponseBase):
|
||||
"""Response for completion_report."""
|
||||
|
||||
type: ResponseType = ResponseType.COMPLETION_REPORT_SAVED
|
||||
has_pending_approvals: bool = False
|
||||
pending_approval_count: int = 0
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
@@ -436,9 +445,9 @@ class BlockDetails(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
inputs: dict[str, Any] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
credentials: list[CredentialsMetaInput] = []
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
credentials: list[CredentialsMetaInput] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BlockDetailsResponse(ToolResponseBase):
|
||||
@@ -631,7 +640,7 @@ class FolderInfo(BaseModel):
|
||||
class FolderTreeInfo(FolderInfo):
|
||||
"""Folder with nested children for tree display."""
|
||||
|
||||
children: list["FolderTreeInfo"] = []
|
||||
children: list["FolderTreeInfo"] = Field(default_factory=list)
|
||||
|
||||
|
||||
class FolderCreatedResponse(ToolResponseBase):
|
||||
@@ -678,6 +687,6 @@ class AgentsMovedToFolderResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
|
||||
agent_ids: list[str]
|
||||
agent_names: list[str] = []
|
||||
agent_names: list[str] = Field(default_factory=list)
|
||||
folder_id: str | None = None
|
||||
count: int = 0
|
||||
|
||||
@@ -12,7 +12,6 @@ from backend.copilot.constants import (
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
@@ -198,29 +197,6 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Expand @@agptfile: refs in input_data with the block's input
|
||||
# schema. The generic _truncating wrapper skips opaque object
|
||||
# properties (input_data has no declared inner properties in the
|
||||
# tool schema), so file ref tokens are still intact here.
|
||||
# Using the block's schema lets us return raw text for string-typed
|
||||
# fields and parsed structures for list/dict-typed fields.
|
||||
if input_data:
|
||||
try:
|
||||
input_data = await expand_file_refs_in_args(
|
||||
input_data,
|
||||
user_id,
|
||||
session,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Failed to resolve file reference: {exc}. "
|
||||
"Ensure the file exists before referencing it."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
@@ -10,11 +10,11 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_workspace_manager,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -218,6 +218,12 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
@@ -380,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -530,7 +536,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -766,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -893,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -92,6 +92,19 @@ def user_db():
|
||||
return user_db
|
||||
|
||||
|
||||
def invited_user_db():
|
||||
if db.is_connected():
|
||||
from backend.data import invited_user as _invited_user_db
|
||||
|
||||
invited_user_db = _invited_user_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
invited_user_db = get_database_manager_async_client()
|
||||
|
||||
return invited_user_db
|
||||
|
||||
|
||||
def understanding_db():
|
||||
if db.is_connected():
|
||||
from backend.data import understanding as _understanding_db
|
||||
|
||||
@@ -79,6 +79,7 @@ from backend.data.graph import (
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
count_pending_reviews_for_graph_exec,
|
||||
delete_review_by_node_exec_id,
|
||||
get_or_create_human_review,
|
||||
get_pending_reviews_for_execution,
|
||||
@@ -86,6 +87,7 @@ from backend.data.human_review import (
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
)
|
||||
from backend.data.invited_user import list_invited_users_for_auth_users
|
||||
from backend.data.notifications import (
|
||||
clear_all_user_notification_batches,
|
||||
create_or_add_to_user_notification_batch,
|
||||
@@ -107,6 +109,7 @@ from backend.data.user import (
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
get_user_notification_preference,
|
||||
list_users,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.data.workspace import (
|
||||
@@ -115,6 +118,7 @@ from backend.data.workspace import (
|
||||
get_or_create_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_file_by_path,
|
||||
get_workspace_files_by_ids,
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
@@ -237,6 +241,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
list_users = _(list_users)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
@@ -249,6 +254,7 @@ class DatabaseManager(AppService):
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
count_pending_reviews_for_graph_exec = _(count_pending_reviews_for_graph_exec)
|
||||
delete_review_by_node_exec_id = _(delete_review_by_node_exec_id)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
get_pending_reviews_for_execution = _(get_pending_reviews_for_execution)
|
||||
@@ -313,12 +319,16 @@ class DatabaseManager(AppService):
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = _(count_workspace_files)
|
||||
create_workspace_file = _(create_workspace_file)
|
||||
get_workspace_files_by_ids = _(get_workspace_files_by_ids)
|
||||
get_or_create_workspace = _(get_or_create_workspace)
|
||||
get_workspace_file = _(get_workspace_file)
|
||||
get_workspace_file_by_path = _(get_workspace_file_by_path)
|
||||
list_workspace_files = _(list_workspace_files)
|
||||
soft_delete_workspace_file = _(soft_delete_workspace_file)
|
||||
|
||||
# ============ Invited Users ============ #
|
||||
list_invited_users_for_auth_users = _(list_invited_users_for_auth_users)
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
@@ -328,8 +338,28 @@ class DatabaseManager(AppService):
|
||||
update_block_optimized_description = _(update_block_optimized_description)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_messages_since = _(chat_db.get_chat_messages_since)
|
||||
get_chat_session_callback_token = _(chat_db.get_chat_session_callback_token)
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session_callback_token = _(chat_db.create_chat_session_callback_token)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
get_manual_chat_sessions_since = _(chat_db.get_manual_chat_sessions_since)
|
||||
get_pending_notification_chat_sessions = _(
|
||||
chat_db.get_pending_notification_chat_sessions
|
||||
)
|
||||
get_pending_notification_chat_sessions_for_user = _(
|
||||
chat_db.get_pending_notification_chat_sessions_for_user
|
||||
)
|
||||
get_recent_completion_report_chat_sessions = _(
|
||||
chat_db.get_recent_completion_report_chat_sessions
|
||||
)
|
||||
get_recent_sent_email_chat_sessions = _(chat_db.get_recent_sent_email_chat_sessions)
|
||||
has_recent_manual_message = _(chat_db.has_recent_manual_message)
|
||||
has_session_since = _(chat_db.has_session_since)
|
||||
mark_chat_session_callback_token_consumed = _(
|
||||
chat_db.mark_chat_session_callback_token_consumed
|
||||
)
|
||||
session_exists_for_execution_tag = _(chat_db.session_exists_for_execution_tag)
|
||||
update_chat_session = _(chat_db.update_chat_session)
|
||||
add_chat_message = _(chat_db.add_chat_message)
|
||||
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
|
||||
@@ -374,10 +404,18 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_marketplace_graphs_for_monitoring = _(d.get_marketplace_graphs_for_monitoring)
|
||||
|
||||
# Human In The Loop
|
||||
count_pending_reviews_for_graph_exec = _(d.count_pending_reviews_for_graph_exec)
|
||||
has_pending_reviews_for_graph_exec = _(d.has_pending_reviews_for_graph_exec)
|
||||
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
list_users = _(d.list_users)
|
||||
|
||||
# CoPilot Chat Sessions
|
||||
get_recent_completion_report_chat_sessions = _(
|
||||
d.get_recent_completion_report_chat_sessions
|
||||
)
|
||||
get_recent_sent_email_chat_sessions = _(d.get_recent_sent_email_chat_sessions)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
@@ -433,12 +471,14 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = d.get_user_by_id
|
||||
list_users = d.list_users
|
||||
get_user_integrations = d.get_user_integrations
|
||||
update_user_integrations = d.update_user_integrations
|
||||
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
count_pending_reviews_for_graph_exec = d.count_pending_reviews_for_graph_exec
|
||||
delete_review_by_node_exec_id = d.delete_review_by_node_exec_id
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
get_pending_reviews_for_execution = d.get_pending_reviews_for_execution
|
||||
@@ -506,12 +546,16 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = d.count_workspace_files
|
||||
create_workspace_file = d.create_workspace_file
|
||||
get_workspace_files_by_ids = d.get_workspace_files_by_ids
|
||||
get_or_create_workspace = d.get_or_create_workspace
|
||||
get_workspace_file = d.get_workspace_file
|
||||
get_workspace_file_by_path = d.get_workspace_file_by_path
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Invited Users ============ #
|
||||
list_invited_users_for_auth_users = d.list_invited_users_for_auth_users
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
@@ -520,8 +564,26 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_blocks_needing_optimization = d.get_blocks_needing_optimization
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_messages_since = d.get_chat_messages_since
|
||||
get_chat_session_callback_token = d.get_chat_session_callback_token
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session_callback_token = d.create_chat_session_callback_token
|
||||
create_chat_session = d.create_chat_session
|
||||
get_manual_chat_sessions_since = d.get_manual_chat_sessions_since
|
||||
get_pending_notification_chat_sessions = d.get_pending_notification_chat_sessions
|
||||
get_pending_notification_chat_sessions_for_user = (
|
||||
d.get_pending_notification_chat_sessions_for_user
|
||||
)
|
||||
get_recent_completion_report_chat_sessions = (
|
||||
d.get_recent_completion_report_chat_sessions
|
||||
)
|
||||
get_recent_sent_email_chat_sessions = d.get_recent_sent_email_chat_sessions
|
||||
has_recent_manual_message = d.has_recent_manual_message
|
||||
has_session_since = d.has_session_since
|
||||
mark_chat_session_callback_token_consumed = (
|
||||
d.mark_chat_session_callback_token_consumed
|
||||
)
|
||||
session_exists_for_execution_tag = d.session_exists_for_execution_tag
|
||||
update_chat_session = d.update_chat_session
|
||||
add_chat_message = d.add_chat_message
|
||||
add_chat_messages_batch = d.add_chat_messages_batch
|
||||
|
||||
@@ -342,6 +342,19 @@ async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
return count > 0
|
||||
|
||||
|
||||
async def count_pending_reviews_for_graph_exec(
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
) -> int:
|
||||
return await PendingHumanReview.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_id(node_exec_id: str, get_node_execution) -> str:
|
||||
"""Resolve node_id from a node_exec_id.
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from backend.data.model import User
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
parse_business_understanding_input,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
@@ -63,18 +63,16 @@ class InvitedUserRecord(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
payload = parse_business_understanding_input(invited_user.tallyUnderstanding)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_understanding=(
|
||||
payload.model_dump(mode="json") if payload is not None else None
|
||||
),
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
@@ -185,19 +183,13 @@ async def _apply_tally_understanding(
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
input_data = parse_business_understanding_input(invited_user.tallyUnderstanding)
|
||||
if input_data is None:
|
||||
if invited_user.tallyUnderstanding is not None:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
)
|
||||
return
|
||||
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
@@ -223,6 +215,18 @@ async def list_invited_users(
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def list_invited_users_for_auth_users(
|
||||
auth_user_ids: list[str],
|
||||
) -> list[InvitedUserRecord]:
|
||||
if not auth_user_ids:
|
||||
return []
|
||||
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
where={"authUserId": {"in": auth_user_ids}}
|
||||
)
|
||||
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
|
||||
@@ -31,6 +31,18 @@ def _json_to_list(value: Any) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def parse_business_understanding_input(
|
||||
payload: Any,
|
||||
) -> "BusinessUnderstandingInput | None":
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return BusinessUnderstandingInput.model_validate(payload)
|
||||
except pydantic.ValidationError:
|
||||
return None
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
|
||||
@@ -62,6 +62,61 @@ async def get_user_by_id(user_id: str) -> User:
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def list_users(
|
||||
limit: int = 500,
|
||||
cursor: str | None = None,
|
||||
) -> list[User]:
|
||||
try:
|
||||
kwargs: dict = {
|
||||
"take": limit,
|
||||
"order": {"id": "asc"},
|
||||
}
|
||||
if cursor is not None:
|
||||
kwargs["cursor"] = {"id": cursor}
|
||||
kwargs["skip"] = 1
|
||||
users = await PrismaUser.prisma().find_many(**kwargs)
|
||||
return [User.from_db(user) for user in users]
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to list users: {e}") from e
|
||||
|
||||
|
||||
async def search_users(query: str, limit: int = 20) -> list[User]:
|
||||
normalized_query = query.strip()
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
try:
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{
|
||||
"email": {
|
||||
"contains": normalized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": {
|
||||
"contains": normalized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": {
|
||||
"contains": normalized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
},
|
||||
]
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [User.from_db(user) for user in users]
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to search users for query {query!r}: {e}") from e
|
||||
|
||||
|
||||
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
|
||||
@@ -254,6 +254,23 @@ async def list_workspace_files(
|
||||
return [WorkspaceFile.from_db(f) for f in files]
|
||||
|
||||
|
||||
async def get_workspace_files_by_ids(
|
||||
workspace_id: str,
|
||||
file_ids: list[str],
|
||||
) -> list[WorkspaceFile]:
|
||||
if not file_ids:
|
||||
return []
|
||||
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": file_ids},
|
||||
"workspaceId": workspace_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
return [WorkspaceFile.from_db(file) for file in files]
|
||||
|
||||
|
||||
async def count_workspace_files(
|
||||
workspace_id: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
|
||||
@@ -24,6 +24,12 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.copilot.autopilot import (
|
||||
dispatch_nightly_copilot as dispatch_nightly_copilot_async,
|
||||
)
|
||||
from backend.copilot.autopilot import (
|
||||
send_nightly_copilot_emails as send_nightly_copilot_emails_async,
|
||||
)
|
||||
from backend.copilot.optimize_blocks import optimize_block_descriptions
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput, GraphInput
|
||||
@@ -259,6 +265,16 @@ def cleanup_oauth_tokens():
|
||||
run_async(_cleanup())
|
||||
|
||||
|
||||
def dispatch_nightly_copilot():
|
||||
"""Dispatch proactive nightly copilot sessions."""
|
||||
return run_async(dispatch_nightly_copilot_async())
|
||||
|
||||
|
||||
def send_nightly_copilot_emails():
|
||||
"""Send emails for completed non-manual copilot sessions."""
|
||||
return run_async(send_nightly_copilot_emails_async())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
"""Check execution accuracy and send alerts if drops are detected."""
|
||||
return report_execution_accuracy_alerts()
|
||||
@@ -404,7 +420,7 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
) -> "GraphExecutionJobInfo":
|
||||
# Extract timezone from the trigger if it's a CronTrigger
|
||||
timezone_str = "UTC"
|
||||
if hasattr(job_obj.trigger, "timezone"):
|
||||
if isinstance(job_obj.trigger, CronTrigger):
|
||||
timezone_str = str(job_obj.trigger.timezone)
|
||||
|
||||
return GraphExecutionJobInfo(
|
||||
@@ -619,6 +635,24 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_job(
|
||||
dispatch_nightly_copilot,
|
||||
id="dispatch_nightly_copilot",
|
||||
trigger=CronTrigger(minute="0,30", timezone=ZoneInfo("UTC")),
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_job(
|
||||
send_nightly_copilot_emails,
|
||||
id="send_nightly_copilot_emails",
|
||||
trigger=CronTrigger(minute="15,45", timezone=ZoneInfo("UTC")),
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
@@ -793,6 +827,14 @@ class Scheduler(AppService):
|
||||
"""Manually trigger embedding backfill for approved store agents."""
|
||||
return ensure_embeddings_coverage()
|
||||
|
||||
@expose
|
||||
def execute_dispatch_nightly_copilot(self):
|
||||
return dispatch_nightly_copilot()
|
||||
|
||||
@expose
|
||||
def execute_send_nightly_copilot_emails(self):
|
||||
return send_nightly_copilot_emails()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -25,35 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time; calling this again replaces the
|
||||
previous hook. Intended to be called once at application startup by the
|
||||
copilot module to bust its token cache without creating an import cycle.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def _bust_copilot_cache(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered hook (if any) to bust downstream token caches."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -98,11 +69,7 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Bust the copilot token cache so that the next bash_exec picks up the
|
||||
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
|
||||
_bust_copilot_cache(user_id, credentials.provider)
|
||||
return result
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -189,8 +156,6 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Bust copilot cache so the refreshed token is picked up immediately.
|
||||
_bust_copilot_cache(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
@@ -203,17 +168,10 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Bust the copilot token cache so the updated credential is picked up immediately.
|
||||
_bust_copilot_cache(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_bust_copilot_cache(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Any
|
||||
|
||||
from postmarker.core import PostmarkClient
|
||||
from postmarker.models.emails import EmailManager
|
||||
@@ -7,12 +8,14 @@ from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
NotificationDataType_co,
|
||||
NotificationEventModel,
|
||||
NotificationTypeOverride,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.url import get_frontend_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -46,6 +49,102 @@ class EmailSender:
|
||||
|
||||
MAX_EMAIL_CHARS = 5_000_000 # ~5MB buffer
|
||||
|
||||
def _get_unsubscribe_link(self, user_unsubscribe_link: str | None) -> str:
|
||||
return user_unsubscribe_link or f"{get_frontend_base_url()}/profile/settings"
|
||||
|
||||
def _format_template_email(
|
||||
self,
|
||||
*,
|
||||
subject_template: str,
|
||||
content_template: str,
|
||||
data: Any,
|
||||
unsubscribe_link: str,
|
||||
) -> tuple[str, str]:
|
||||
return self.formatter.format_email(
|
||||
base_template=self._read_template("templates/base.html.jinja2"),
|
||||
subject_template=subject_template,
|
||||
content_template=content_template,
|
||||
data=data,
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def _build_large_output_summary(
|
||||
self,
|
||||
data: (
|
||||
NotificationEventModel[NotificationDataType_co]
|
||||
| list[NotificationEventModel[NotificationDataType_co]]
|
||||
),
|
||||
*,
|
||||
email_size: int,
|
||||
base_url: str,
|
||||
) -> str:
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
return (
|
||||
"⚠️ A notification generated a very large output "
|
||||
f"({email_size / 1_000_000:.2f} MB)."
|
||||
)
|
||||
event = data[0]
|
||||
else:
|
||||
event = data
|
||||
execution_url = (
|
||||
f"{base_url}/executions/{event.id}" if event.id is not None else None
|
||||
)
|
||||
|
||||
if isinstance(event.data, AgentRunData):
|
||||
lines = [
|
||||
f"⚠️ Your agent '{event.data.agent_name}' generated a very large output ({email_size / 1_000_000:.2f} MB).",
|
||||
"",
|
||||
f"Execution time: {event.data.execution_time}",
|
||||
f"Credits used: {event.data.credits_used}",
|
||||
]
|
||||
if execution_url is not None:
|
||||
lines.append(f"View full results: {execution_url}")
|
||||
return "\n".join(lines)
|
||||
|
||||
lines = [
|
||||
f"⚠️ A notification generated a very large output ({email_size / 1_000_000:.2f} MB).",
|
||||
]
|
||||
if execution_url is not None:
|
||||
lines.extend(["", f"View full results: {execution_url}"])
|
||||
return "\n".join(lines)
|
||||
|
||||
def send_template(
|
||||
self,
|
||||
*,
|
||||
user_email: str,
|
||||
subject: str,
|
||||
template_name: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
user_unsubscribe_link: str | None = None,
|
||||
) -> None:
|
||||
"""Send an email using a named Jinja2 template file.
|
||||
|
||||
Unlike ``send_templated`` (which resolves templates via
|
||||
``NotificationType``), this method accepts a template filename
|
||||
directly. Both delegate to the shared ``_format_template_email``
|
||||
+ ``_send_email`` pipeline.
|
||||
"""
|
||||
if not self.postmark:
|
||||
logger.warning("Postmark client not initialized, email not sent")
|
||||
return
|
||||
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
|
||||
|
||||
_, full_message = self._format_template_email(
|
||||
subject_template="{{ subject }}",
|
||||
content_template=self._read_template(f"templates/{template_name}"),
|
||||
data={"subject": subject, **(data or {})},
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
self._send_email(
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=full_message,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def send_templated(
|
||||
self,
|
||||
notification: NotificationType,
|
||||
@@ -62,21 +161,18 @@ class EmailSender:
|
||||
return
|
||||
|
||||
template = self._get_template(notification)
|
||||
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
base_url = get_frontend_base_url()
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsub_link)
|
||||
|
||||
# Normalize data
|
||||
template_data = {"notifications": data} if isinstance(data, list) else data
|
||||
|
||||
try:
|
||||
subject, full_message = self.formatter.format_email(
|
||||
base_template=template.base_template,
|
||||
subject, full_message = self._format_template_email(
|
||||
subject_template=template.subject_template,
|
||||
content_template=template.body_template,
|
||||
data=template_data,
|
||||
unsubscribe_link=f"{base_url}/profile/settings",
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting full message: {e}")
|
||||
@@ -90,20 +186,17 @@ class EmailSender:
|
||||
"Sending summary email instead."
|
||||
)
|
||||
|
||||
# Create lightweight summary
|
||||
summary_message = (
|
||||
f"⚠️ Your agent '{getattr(data, 'agent_name', 'Unknown')}' "
|
||||
f"generated a very large output ({email_size / 1_000_000:.2f} MB).\n\n"
|
||||
f"Execution time: {getattr(data, 'execution_time', 'N/A')}\n"
|
||||
f"Credits used: {getattr(data, 'credits_used', 'N/A')}\n"
|
||||
f"View full results: {base_url}/executions/{getattr(data, 'id', 'N/A')}"
|
||||
summary_message = self._build_large_output_summary(
|
||||
data,
|
||||
email_size=email_size,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
self._send_email(
|
||||
user_email=user_email,
|
||||
subject=f"{subject} (Output Too Large)",
|
||||
body=summary_message,
|
||||
user_unsubscribe_link=user_unsub_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
return # Skip sending full email
|
||||
|
||||
@@ -112,7 +205,7 @@ class EmailSender:
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=full_message,
|
||||
user_unsubscribe_link=user_unsub_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def _get_template(self, notification: NotificationType):
|
||||
@@ -123,17 +216,18 @@ class EmailSender:
|
||||
logger.debug(
|
||||
f"Template full path: {pathlib.Path(__file__).parent / template_path}"
|
||||
)
|
||||
base_template_path = "templates/base.html.jinja2"
|
||||
with open(pathlib.Path(__file__).parent / base_template_path, "r") as file:
|
||||
base_template = file.read()
|
||||
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
|
||||
template = file.read()
|
||||
base_template = self._read_template("templates/base.html.jinja2")
|
||||
template = self._read_template(template_path)
|
||||
return Template(
|
||||
subject_template=notification_type_override.subject,
|
||||
body_template=template,
|
||||
base_template=base_template,
|
||||
)
|
||||
|
||||
def _read_template(self, template_path: str) -> str:
|
||||
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
|
||||
return file.read()
|
||||
|
||||
def _send_email(
|
||||
self,
|
||||
user_email: str,
|
||||
@@ -144,18 +238,33 @@ class EmailSender:
|
||||
if not self.postmark:
|
||||
logger.warning("Email tried to send without postmark configured")
|
||||
return
|
||||
sender_email = settings.config.postmark_sender_email
|
||||
if not sender_email:
|
||||
logger.warning("postmark_sender_email not configured, email not sent")
|
||||
return
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
|
||||
logger.debug(f"Sending email to {user_email} with subject {subject}")
|
||||
self.postmark.emails.send(
|
||||
From=settings.config.postmark_sender_email,
|
||||
From=sender_email,
|
||||
To=user_email,
|
||||
Subject=subject,
|
||||
HtmlBody=body,
|
||||
Headers=(
|
||||
{
|
||||
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
|
||||
"List-Unsubscribe": f"<{user_unsubscribe_link}>",
|
||||
}
|
||||
if user_unsubscribe_link
|
||||
else None
|
||||
),
|
||||
Headers={
|
||||
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
|
||||
"List-Unsubscribe": f"<{unsubscribe_link}>",
|
||||
},
|
||||
)
|
||||
|
||||
def send_html(
|
||||
self,
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
user_unsubscribe_link: str | None = None,
|
||||
) -> None:
|
||||
self._send_email(
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
user_unsubscribe_link=user_unsubscribe_link,
|
||||
)
|
||||
|
||||
241
autogpt_platform/backend/backend/notifications/email_test.py
Normal file
241
autogpt_platform/backend/backend/notifications/email_test.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.api.test_helpers import override_config
|
||||
from backend.copilot.autopilot_email import _markdown_to_email_html
|
||||
from backend.notifications.email import EmailSender, settings
|
||||
from backend.util.settings import AppEnvironment
|
||||
|
||||
|
||||
def test_markdown_to_email_html_renders_bold_and_italic() -> None:
|
||||
html = _markdown_to_email_html("**bold** and *italic*")
|
||||
assert "<strong>bold</strong>" in html
|
||||
assert "<em>italic</em>" in html
|
||||
assert 'style="' in html
|
||||
|
||||
|
||||
def test_markdown_to_email_html_renders_links() -> None:
|
||||
html = _markdown_to_email_html("[click here](https://example.com)")
|
||||
assert 'href="https://example.com"' in html
|
||||
assert "click here" in html
|
||||
assert "color: #7733F5" in html
|
||||
|
||||
|
||||
def test_markdown_to_email_html_renders_bullet_list() -> None:
|
||||
html = _markdown_to_email_html("- item one\n- item two")
|
||||
assert "<ul" in html
|
||||
assert "<li" in html
|
||||
assert "item one" in html
|
||||
assert "item two" in html
|
||||
|
||||
|
||||
def test_markdown_to_email_html_handles_empty_input() -> None:
|
||||
assert _markdown_to_email_html(None) == ""
|
||||
assert _markdown_to_email_html("") == ""
|
||||
assert _markdown_to_email_html(" ") == ""
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_email(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I found something useful for you.\n\n"
|
||||
"Open Copilot and I will walk you through it."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Open Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "I found something useful for you." in body
|
||||
assert "Open Copilot" in body
|
||||
assert "Approval needed" not in body
|
||||
assert send_email.call_args.kwargs["user_unsubscribe_link"].endswith(
|
||||
"/profile/settings"
|
||||
)
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_approval_block(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I prepared a change worth reviewing."
|
||||
),
|
||||
"approval_summary_html": _markdown_to_email_html(
|
||||
"I drafted a follow-up because it matches your recent activity."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
|
||||
"cta_label": "Review in Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Approval needed" in body
|
||||
assert "If you want it to happen, please hit approve." in body
|
||||
assert "Review in Copilot" in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_callback_email(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_callback.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I prepared a follow-up based on your recent work."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Open Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Autopilot picked up where you left off" in body
|
||||
assert "I prepared a follow-up based on your recent work." in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_callback_approval_block(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_callback.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I prepared a follow-up based on your recent work."
|
||||
),
|
||||
"approval_summary_html": _markdown_to_email_html(
|
||||
"I want your approval before I apply the next step."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
|
||||
"cta_label": "Review in Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Approval needed" in body
|
||||
assert "I want your approval before I apply the next step." in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_invite_cta_email(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_invite_cta.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I put together an example of how Autopilot could help you."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Try Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Your Autopilot beta access is waiting" in body
|
||||
assert "I put together an example of how Autopilot could help you." in body
|
||||
assert "Try Copilot" in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_invite_cta_approval_block(
|
||||
mocker,
|
||||
) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_invite_cta.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I put together an example of how Autopilot could help you."
|
||||
),
|
||||
"approval_summary_html": _markdown_to_email_html(
|
||||
"If this looks useful, approve the next step to try it."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
|
||||
"cta_label": "Review in Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Approval needed" in body
|
||||
assert "If this looks useful, approve the next step to try it." in body
|
||||
|
||||
|
||||
def test_send_template_still_sends_in_production(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
with override_config(settings, "app_env", AppEnvironment.PRODUCTION):
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I found something useful for you."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Open Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
send_email.assert_called_once()
|
||||
|
||||
|
||||
def test_send_html_uses_default_unsubscribe_link(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
send = mocker.Mock()
|
||||
sender.postmark = cast(Any, SimpleNamespace(emails=SimpleNamespace(send=send)))
|
||||
|
||||
mocker.patch(
|
||||
"backend.notifications.email.get_frontend_base_url",
|
||||
return_value="https://example.com",
|
||||
)
|
||||
with override_config(settings, "postmark_sender_email", "test@example.com"):
|
||||
sender.send_html(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
body="<p>Hello</p>",
|
||||
)
|
||||
|
||||
headers = send.call_args.kwargs["Headers"]
|
||||
|
||||
assert headers["List-Unsubscribe-Post"] == "List-Unsubscribe=One-Click"
|
||||
assert headers["List-Unsubscribe"] == "<https://example.com/profile/settings>"
|
||||
@@ -29,7 +29,6 @@
|
||||
</noscript>
|
||||
<![endif]-->
|
||||
<style type="text/css">
|
||||
/* RESET STYLES */
|
||||
html,
|
||||
body {
|
||||
margin: 0 !important;
|
||||
@@ -85,7 +84,6 @@
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
/* iOS BLUE LINKS */
|
||||
a[x-apple-data-detectors] {
|
||||
color: inherit !important;
|
||||
text-decoration: none !important;
|
||||
@@ -95,13 +93,11 @@
|
||||
line-height: inherit !important;
|
||||
}
|
||||
|
||||
/* ANDROID CENTER FIX */
|
||||
div[style*="margin: 16px 0;"] {
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
/* MEDIA QUERIES */
|
||||
@media all and (max-width:639px) {
|
||||
@media all and (max-width: 639px) {
|
||||
.wrapper {
|
||||
width: 100% !important;
|
||||
}
|
||||
@@ -113,8 +109,8 @@
|
||||
}
|
||||
|
||||
.row {
|
||||
padding-left: 20px !important;
|
||||
padding-right: 20px !important;
|
||||
padding-left: 24px !important;
|
||||
padding-right: 24px !important;
|
||||
}
|
||||
|
||||
.col-mobile {
|
||||
@@ -136,11 +132,6 @@
|
||||
float: none !important;
|
||||
}
|
||||
|
||||
.mobile-left {
|
||||
text-align: center !important;
|
||||
float: left !important;
|
||||
}
|
||||
|
||||
.mobile-hide {
|
||||
display: none !important;
|
||||
}
|
||||
@@ -155,9 +146,9 @@
|
||||
max-width: 100% !important;
|
||||
}
|
||||
|
||||
.ml-btn-container {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
.card-inner {
|
||||
padding-left: 24px !important;
|
||||
padding-right: 24px !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -174,170 +165,139 @@
|
||||
<title>{{data.title}}</title>
|
||||
</head>
|
||||
|
||||
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
|
||||
<body style="margin: 0 !important; padding: 0 !important; background-color: #070629;">
|
||||
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
|
||||
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
|
||||
<!-- Main Content -->
|
||||
style="background-color: #070629; line-height: 100%; font-size: medium; font-size: max(16px, 1rem);">
|
||||
|
||||
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
|
||||
<tr>
|
||||
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
|
||||
<!-- Email Content -->
|
||||
<td align="center" valign="top" style="padding: 48px 16px 40px;">
|
||||
|
||||
<!-- ============ CARD ============ -->
|
||||
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
|
||||
<!-- Gradient Accent Strip -->
|
||||
<tr>
|
||||
<td align="center">
|
||||
<!-- Logo Section -->
|
||||
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td class="row" align="center" style="padding: 0 50px;">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="" width="120" class="logo"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<td bgcolor="#7733F5"
|
||||
style="background: linear-gradient(90deg, #7733F5 0%, #60A5FA 35%, #EC4899 65%, #7733F5 100%); height: 6px; border-radius: 16px 16px 0 0; font-size: 0; line-height: 0;">
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Main Content Section -->
|
||||
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="row" style="padding: 0 50px;">
|
||||
{{data.message|safe}}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<!-- Logo -->
|
||||
<tr>
|
||||
<td bgcolor="#FFFFFF" align="center" class="card-inner" style="padding: 32px 48px 24px;">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="AutoGPT" width="120"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Signature Section -->
|
||||
<table class="container ml-8 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
|
||||
<!-- Divider -->
|
||||
<tr>
|
||||
<td bgcolor="#FFFFFF" class="card-inner" style="padding: 0 48px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td class="row mobile-center" align="left" style="padding: 0 50px;">
|
||||
<table class="ml-8 wrapper" border="0" cellspacing="0" cellpadding="0"
|
||||
style="color: #070629; text-align: left;">
|
||||
<tr>
|
||||
<td class="col center mobile-center" align>
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 0;">
|
||||
Thank you for being a part of the AutoGPT community! Join the conversation on our Discord <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">here</a> and share your thoughts with us anytime.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Footer Section -->
|
||||
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="row" style="padding: 0 50px;">
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td>
|
||||
<!-- Footer Content -->
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td class="col" align="left" valign="middle" width="120">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="" width="120" class="logo"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
|
||||
<td class="col mobile-left" align="right" valign="middle" width="250">
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
|
||||
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
|
||||
width="18" alt="x">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
|
||||
<a href="https://discord.gg/autogpt" target="blank"
|
||||
style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
|
||||
width="18" alt="discord">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
|
||||
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
|
||||
width="18" alt="website">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="text-align: left!important;">
|
||||
<h5
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 15px; line-height: 125%; font-weight: bold; font-style: normal; text-decoration: none; margin-bottom: 6px;">
|
||||
AutoGPT
|
||||
</h5>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="text-align: left!important;">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
3rd Floor 1 Ashley Road, Cheshire, United Kingdom, WA14 2DT, Altrincham<br>United Kingdom
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="8" style="line-height: 8px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" style="text-align: left!important;">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
You received this email because you signed up on our website.</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="1" style="line-height: 12px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
<a href="{{data.unsubscribe_link}}"
|
||||
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">Unsubscribe</a>
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td style="border-top: 1px solid #E8E5F0; font-size: 0; line-height: 0; height: 1px;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Content -->
|
||||
<tr>
|
||||
<td class="card-inner" bgcolor="#FFFFFF"
|
||||
style="padding: 36px 48px 44px; color: #1F1F20; font-family: 'Poppins', sans-serif; border-radius: 0 0 16px 16px;">
|
||||
{{data.message|safe}}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
<!-- ============ END CARD ============ -->
|
||||
|
||||
<!-- Spacer -->
|
||||
<table width="640" class="container" align="center" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
<tr>
|
||||
<td style="height: 40px; font-size: 0; line-height: 0;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- ============ FOOTER ============ -->
|
||||
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
<tr>
|
||||
<td align="center" style="padding: 0 48px;">
|
||||
|
||||
<!-- Logo Text -->
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; font-size: 22px; font-weight: 700; color: #FFFFFF; margin: 0 0 16px; letter-spacing: -0.5px;">
|
||||
AutoGPT</p>
|
||||
|
||||
<!-- Social Icons -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" align="center">
|
||||
<tr>
|
||||
<td align="center" valign="middle" style="padding: 0 8px;">
|
||||
<a href="https://x.com/auto_gpt" target="_blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/x.png"
|
||||
width="22" alt="X" style="display: block; opacity: 0.6;">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" style="padding: 0 8px;">
|
||||
<a href="https://discord.gg/autogpt" target="_blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/discord.png"
|
||||
width="22" alt="Discord" style="display: block; opacity: 0.6;">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" style="padding: 0 8px;">
|
||||
<a href="https://agpt.co/" target="_blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/website.png"
|
||||
width="22" alt="Website" style="display: block; opacity: 0.6;">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Spacer -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="height: 20px; font-size: 0; line-height: 0;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Divider -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td
|
||||
style="border-top: 1px solid rgba(255,255,255,0.08); font-size: 0; line-height: 0; height: 1px;">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Footer Text -->
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: rgba(255,255,255,0.3); font-size: 12px; line-height: 165%; margin: 16px 0 0; text-align: center;">
|
||||
AutoGPT · 3rd Floor 1 Ashley Road, Altrincham, WA14 2DT, United Kingdom
|
||||
</p>
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: rgba(255,255,255,0.3); font-size: 12px; line-height: 165%; margin: 4px 0 0; text-align: center;">
|
||||
You received this email because you signed up on our website.
|
||||
<a href="{{data.unsubscribe_link}}"
|
||||
style="color: rgba(255,255,255,0.45); text-decoration: underline;">Unsubscribe</a>
|
||||
</p>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
|
||||
{{ email_body_html|safe }}
|
||||
|
||||
{% if approval_summary_html %}
|
||||
<!-- Approval Callout -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
|
||||
<tr>
|
||||
<td bgcolor="#FFF3E6"
|
||||
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="padding: 0 0 12px 0;">
|
||||
<span
|
||||
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
|
||||
Approval needed
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{{ approval_summary_html|safe }}
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
|
||||
I thought this was a good idea. If you want it to happen, please hit approve.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<!-- CTA Button -->
|
||||
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#7733F5"
|
||||
style="background-color: #7733F5; border-radius: 12px;">
|
||||
<a href="{{ cta_url }}"
|
||||
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
|
||||
{{ cta_label }}
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
@@ -0,0 +1,58 @@
|
||||
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
|
||||
<!-- Header -->
|
||||
<h2
|
||||
style="font-size: 24px; line-height: 130%; font-weight: 700; margin-top: 0; margin-bottom: 8px; color: #1F1F20; letter-spacing: -0.5px;">
|
||||
Autopilot picked up where you left off
|
||||
</h2>
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 28px; color: #505057;">
|
||||
We used your recent Copilot activity to prepare a concrete follow-up for you.
|
||||
</p>
|
||||
|
||||
<!-- Divider -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-bottom: 28px;">
|
||||
<tr>
|
||||
<td style="border-top: 1px solid #E8E5F0; font-size: 0; line-height: 0; height: 1px;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Body -->
|
||||
{{ email_body_html|safe }}
|
||||
|
||||
{% if approval_summary_html %}
|
||||
<!-- Approval Callout -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
|
||||
<tr>
|
||||
<td bgcolor="#FFF3E6"
|
||||
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="padding: 0 0 12px 0;">
|
||||
<span
|
||||
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
|
||||
Approval needed
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{{ approval_summary_html|safe }}
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
|
||||
I thought this was a good idea. If you want it to happen, please hit approve.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<!-- CTA Button -->
|
||||
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#7733F5"
|
||||
style="background-color: #7733F5; border-radius: 12px;">
|
||||
<a href="{{ cta_url }}"
|
||||
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
|
||||
{{ cta_label }}
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
@@ -0,0 +1,64 @@
|
||||
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
|
||||
<!-- Header -->
|
||||
<h2
|
||||
style="font-size: 24px; line-height: 130%; font-weight: 700; margin-top: 0; margin-bottom: 8px; color: #1F1F20; letter-spacing: -0.5px;">
|
||||
Your Autopilot beta access is waiting
|
||||
</h2>
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 28px; color: #505057;">
|
||||
You applied to try Autopilot. Here is a tailored example of how it can help once you jump back in.
|
||||
</p>
|
||||
|
||||
<!-- Highlight Card -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-bottom: 28px;">
|
||||
<tr>
|
||||
<td bgcolor="#5424AE"
|
||||
style="background-color: #5424AE; border-radius: 12px; padding: 20px 24px;">
|
||||
<p
|
||||
style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 0; color: #FFFFFF; font-weight: 500;">
|
||||
Autopilot works in the background to handle tasks, surface insights, and take action on your behalf.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Body -->
|
||||
{{ email_body_html|safe }}
|
||||
|
||||
{% if approval_summary_html %}
|
||||
<!-- Approval Callout -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
|
||||
<tr>
|
||||
<td bgcolor="#FFF3E6"
|
||||
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="padding: 0 0 12px 0;">
|
||||
<span
|
||||
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
|
||||
Approval needed
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{{ approval_summary_html|safe }}
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
|
||||
I thought this was a good idea. If you want it to happen, please hit approve.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<!-- CTA Button -->
|
||||
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#7733F5"
|
||||
style="background-color: #7733F5; border-radius: 12px;">
|
||||
<a href="{{ cta_url }}"
|
||||
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
|
||||
{{ cta_label }}
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
@@ -39,6 +39,7 @@ class Flag(str, Enum):
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
NIGHTLY_COPILOT = "nightly-copilot"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
|
||||
@@ -275,12 +275,13 @@ async def store_media_file(
|
||||
# Process file
|
||||
elif file.startswith("data:"):
|
||||
# Data URI
|
||||
parsed_uri = parse_data_uri(file)
|
||||
if parsed_uri is None:
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Invalid data URI format. Expected data:<mime>;base64,<data>"
|
||||
)
|
||||
mime_type, b64_content = parsed_uri
|
||||
mime_type = match.group(1).strip().lower()
|
||||
b64_content = match.group(2).strip()
|
||||
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
@@ -414,70 +415,13 @@ def get_dir_size(path: Path) -> int:
|
||||
return total
|
||||
|
||||
|
||||
async def resolve_media_content(
|
||||
content: MediaFileType,
|
||||
execution_context: "ExecutionContext",
|
||||
*,
|
||||
return_format: MediaReturnFormat,
|
||||
) -> MediaFileType:
|
||||
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
|
||||
|
||||
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
|
||||
Plain text content (source code, filenames) is returned unchanged. Media
|
||||
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
|
||||
:func:`store_media_file` using *return_format*.
|
||||
|
||||
Use this when a block field is typed as ``MediaFileType`` but may contain
|
||||
either literal text or a media reference.
|
||||
"""
|
||||
if not content or not is_media_file_ref(content):
|
||||
return content
|
||||
return await store_media_file(
|
||||
content, execution_context, return_format=return_format
|
||||
)
|
||||
|
||||
|
||||
def is_media_file_ref(value: str) -> bool:
|
||||
"""Return True if *value* looks like a ``MediaFileType`` reference.
|
||||
|
||||
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
|
||||
formats accepted by :func:`store_media_file`. Plain text content
|
||||
(e.g. source code, filenames) returns False.
|
||||
|
||||
Known limitation: HTTP(S) URL detection is heuristic. Any string that
|
||||
starts with ``http://`` or ``https://`` is treated as a media URL, even
|
||||
if it appears as a URL inside source-code comments or documentation.
|
||||
Blocks that produce source code or Markdown as output may therefore
|
||||
trigger false positives. Callers that need higher precision should
|
||||
inspect the string further (e.g. verify the URL is reachable or has a
|
||||
media-friendly extension).
|
||||
|
||||
Note: this does *not* match local file paths, which are ambiguous
|
||||
(could be filenames or actual paths). Blocks that need to resolve
|
||||
local paths should check for them separately.
|
||||
"""
|
||||
return value.startswith(("data:", "workspace://", "http://", "https://"))
|
||||
|
||||
|
||||
def parse_data_uri(value: str) -> tuple[str, str] | None:
|
||||
"""Parse a ``data:<mime>;base64,<payload>`` URI.
|
||||
|
||||
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
|
||||
or ``None`` if it is not.
|
||||
"""
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1).strip().lower(), match.group(2).strip()
|
||||
|
||||
|
||||
def get_mime_type(file: str) -> str:
|
||||
"""
|
||||
Get the MIME type of a file, whether it's a data URI, URL, or local path.
|
||||
"""
|
||||
if file.startswith("data:"):
|
||||
parsed_uri = parse_data_uri(file)
|
||||
return parsed_uri[0] if parsed_uri else "application/octet-stream"
|
||||
match = re.match(r"^data:([^;]+);base64,", file)
|
||||
return match.group(1) if match else "application/octet-stream"
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
parsed_url = urlparse(file)
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
"""Parse file content into structured Python objects based on file format.
|
||||
|
||||
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
|
||||
formats into native Python types *before* schema-driven coercion runs. This
|
||||
lets blocks with ``Any``-typed inputs receive structured data rather than raw
|
||||
strings, while blocks expecting strings get the value coerced back via
|
||||
``convert()``.
|
||||
|
||||
Supported formats:
|
||||
|
||||
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
|
||||
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
|
||||
when all lines are dicts with the same keys (tabular data), output is
|
||||
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
|
||||
otherwise returns a plain ``list`` of parsed values
|
||||
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
|
||||
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
|
||||
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
|
||||
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
|
||||
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
|
||||
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
|
||||
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
|
||||
|
||||
The **fallback contract** is enforced by :func:`parse_file_content`, not by
|
||||
individual parser functions. If any parser raises, ``parse_file_content``
|
||||
catches the exception and returns the original content unchanged (string for
|
||||
text formats, bytes for binary formats). Callers should never see an
|
||||
exception from the public API when ``strict=False``.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tomllib
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
|
||||
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
|
||||
# unlike os.path.splitext which uses platform-native separators.
|
||||
from posixpath import splitext
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension / MIME → format label mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXT_TO_FORMAT: dict[str, str] = {
|
||||
".json": "json",
|
||||
".jsonl": "jsonl",
|
||||
".ndjson": "jsonl",
|
||||
".csv": "csv",
|
||||
".tsv": "tsv",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".toml": "toml",
|
||||
".parquet": "parquet",
|
||||
".xlsx": "xlsx",
|
||||
}
|
||||
|
||||
MIME_TO_FORMAT: dict[str, str] = {
|
||||
"application/json": "json",
|
||||
"application/x-ndjson": "jsonl",
|
||||
"application/jsonl": "jsonl",
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "tsv",
|
||||
"application/x-yaml": "yaml",
|
||||
"application/yaml": "yaml",
|
||||
"text/yaml": "yaml",
|
||||
"application/toml": "toml",
|
||||
"application/vnd.apache.parquet": "parquet",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
||||
}
|
||||
|
||||
# Formats that require raw bytes rather than decoded text.
|
||||
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_format_from_uri(uri: str) -> str | None:
|
||||
"""Return a format label based on URI extension or MIME fragment.
|
||||
|
||||
Returns ``None`` when the format cannot be determined — the caller should
|
||||
fall back to returning the content as a plain string.
|
||||
"""
|
||||
# 1. Check MIME fragment (workspace://abc123#application/json)
|
||||
if "#" in uri:
|
||||
_, fragment = uri.rsplit("#", 1)
|
||||
fmt = MIME_TO_FORMAT.get(fragment.lower())
|
||||
if fmt:
|
||||
return fmt
|
||||
|
||||
# 2. Check file extension from the path portion.
|
||||
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
|
||||
path = uri.split("#")[0].split("?")[0]
|
||||
_, ext = splitext(path)
|
||||
fmt = _EXT_TO_FORMAT.get(ext.lower())
|
||||
if fmt is not None:
|
||||
return fmt
|
||||
|
||||
# Legacy .xls is not supported — map it so callers can produce a
|
||||
# user-friendly error instead of returning garbled binary.
|
||||
if ext.lower() == ".xls":
|
||||
return "xls"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
|
||||
"""Parse *content* according to *fmt* and return a native Python value.
|
||||
|
||||
When *strict* is ``False`` (default), returns the original *content*
|
||||
unchanged if *fmt* is not recognised or parsing fails for any reason.
|
||||
This mode **never raises**.
|
||||
|
||||
When *strict* is ``True``, parsing errors are propagated to the caller.
|
||||
Unrecognised formats or type mismatches (e.g. text for a binary format)
|
||||
still return *content* unchanged without raising.
|
||||
"""
|
||||
if fmt == "xls":
|
||||
return (
|
||||
"[Unsupported format] Legacy .xls files are not supported. "
|
||||
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
|
||||
)
|
||||
|
||||
try:
|
||||
if fmt in BINARY_FORMATS:
|
||||
parser = _BINARY_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
# Caller gave us text for a binary format — can't parse.
|
||||
return content
|
||||
return parser(content)
|
||||
|
||||
parser = _TEXT_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
return parser(content)
|
||||
|
||||
except PARSE_EXCEPTIONS:
|
||||
if strict:
|
||||
raise
|
||||
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
|
||||
return content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_openpyxl_exception() -> type[Exception]:
|
||||
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
|
||||
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
|
||||
|
||||
return InvalidFileException
|
||||
|
||||
|
||||
def _load_arrow_exception() -> type[Exception]:
|
||||
"""Return pyarrow's ArrowException, raising ImportError if absent."""
|
||||
from pyarrow import ArrowException # noqa: PLC0415
|
||||
|
||||
return ArrowException
|
||||
|
||||
|
||||
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
|
||||
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
|
||||
try:
|
||||
return loader()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
# Exception types that can be raised during file content parsing.
|
||||
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
|
||||
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
|
||||
#
|
||||
# Optional-dependency exception types are loaded via a helper that raises
|
||||
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
|
||||
# This ensures mypy sees clean types and missing deps surface as real errors.
|
||||
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
|
||||
exc
|
||||
for exc in (
|
||||
json.JSONDecodeError,
|
||||
csv.Error,
|
||||
yaml.YAMLError,
|
||||
tomllib.TOMLDecodeError,
|
||||
ValueError,
|
||||
UnicodeDecodeError,
|
||||
ImportError,
|
||||
OSError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
zipfile.BadZipFile,
|
||||
_optional_exc(_load_openpyxl_exception),
|
||||
# ArrowException covers ArrowIOError and ArrowCapacityError which
|
||||
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
|
||||
# already map to ValueError/TypeError but this catches the rest.
|
||||
_optional_exc(_load_arrow_exception),
|
||||
)
|
||||
if exc is not None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text-based parsers (content: str → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
|
||||
"""Parse *content* and return the result only if it is a container (list/dict).
|
||||
|
||||
Scalar values (strings, numbers, booleans, None) are discarded and the
|
||||
original *content* string is returned instead. This prevents e.g. a JSON
|
||||
file containing just ``"42"`` from silently becoming an int.
|
||||
"""
|
||||
parsed = parser(content)
|
||||
if isinstance(parsed, (list, dict)):
|
||||
return parsed
|
||||
return content
|
||||
|
||||
|
||||
def _parse_json(content: str) -> list | dict | str:
|
||||
return _parse_container(json.loads, content)
|
||||
|
||||
|
||||
def _parse_jsonl(content: str) -> Any:
|
||||
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return content
|
||||
|
||||
# When every line is a dict with the same keys, convert to table format
|
||||
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
|
||||
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
|
||||
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
|
||||
keys = list(lines[0].keys())
|
||||
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
|
||||
keys_tuple = tuple(keys)
|
||||
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
|
||||
return [keys] + [[obj[k] for k in keys] for obj in lines]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _parse_csv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter=",")
|
||||
|
||||
|
||||
def _parse_tsv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter="\t")
|
||||
|
||||
|
||||
def _parse_delimited(content: str, *, delimiter: str) -> Any:
|
||||
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
|
||||
# csv.reader never yields [] — blank lines yield [""]. Filter out
|
||||
# rows where every cell is empty (i.e. truly blank lines).
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
if not rows:
|
||||
return content
|
||||
# If the declared delimiter produces only single-column rows, try
|
||||
# sniffing the actual delimiter — catches misidentified files (e.g.
|
||||
# a tab-delimited file with a .csv extension).
|
||||
if len(rows[0]) == 1:
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(content[:8192])
|
||||
if dialect.delimiter != delimiter:
|
||||
reader = csv.reader(io.StringIO(content), dialect)
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
except csv.Error:
|
||||
pass
|
||||
if rows and len(rows[0]) >= 2:
|
||||
return rows
|
||||
return content
|
||||
|
||||
|
||||
def _row_has_content(row: list[str]) -> bool:
|
||||
"""Return True when *row* contains at least one non-empty cell.
|
||||
|
||||
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
|
||||
This predicate filters those out consistently across the initial read
|
||||
and the sniffer-fallback re-read.
|
||||
"""
|
||||
return any(cell for cell in row)
|
||||
|
||||
|
||||
def _parse_yaml(content: str) -> list | dict | str:
|
||||
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
|
||||
# safe_load prevents code execution; for production hardening consider
|
||||
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
|
||||
if "\n---" in content or content.startswith("---\n"):
|
||||
# Multi-document YAML: only the first document is parsed; the rest
|
||||
# are silently ignored by yaml.safe_load. Warn so callers are aware.
|
||||
logger.warning(
|
||||
"Multi-document YAML detected (--- separator); "
|
||||
"only the first document will be parsed."
|
||||
)
|
||||
return _parse_container(yaml.safe_load, content)
|
||||
|
||||
|
||||
def _parse_toml(content: str) -> Any:
|
||||
parsed = tomllib.loads(content)
|
||||
# tomllib.loads always returns a dict — return it even if empty.
|
||||
return parsed
|
||||
|
||||
|
||||
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
|
||||
"json": _parse_json,
|
||||
"jsonl": _parse_jsonl,
|
||||
"csv": _parse_csv,
|
||||
"tsv": _parse_tsv,
|
||||
"yaml": _parse_yaml,
|
||||
"toml": _parse_toml,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary-based parsers (content: bytes → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_parquet(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_parquet(io.BytesIO(content))
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _parse_xlsx(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
# Explicitly specify openpyxl engine; the default engine varies by pandas
|
||||
# version and does not support legacy .xls (which is excluded by our format map).
|
||||
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _df_to_rows(df: Any) -> list[list[Any]]:
|
||||
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
|
||||
|
||||
NaN values are replaced with ``None`` so the result is JSON-serializable.
|
||||
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
|
||||
silently converts ``None`` back to ``NaN`` in float64 columns.
|
||||
"""
|
||||
header = df.columns.tolist()
|
||||
rows = [
|
||||
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
|
||||
]
|
||||
return [header] + rows
|
||||
|
||||
|
||||
def _is_nan(cell: Any) -> bool:
|
||||
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
|
||||
|
||||
``pd.isna()`` on a list/dict returns a boolean array which raises
|
||||
``ValueError`` in a boolean context. Guard with a scalar check first.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
|
||||
|
||||
|
||||
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
|
||||
"parquet": _parse_parquet,
|
||||
"xlsx": _parse_xlsx,
|
||||
}
|
||||
@@ -1,624 +0,0 @@
|
||||
"""Tests for file_content_parser — format inference and structured parsing."""
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# infer_format_from_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferFormat:
|
||||
# --- extension-based ---
|
||||
|
||||
def test_json_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.json") == "json"
|
||||
|
||||
def test_jsonl_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
|
||||
|
||||
def test_ndjson_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
|
||||
|
||||
def test_csv_extension(self):
|
||||
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
|
||||
|
||||
def test_tsv_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
|
||||
|
||||
def test_yaml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
|
||||
|
||||
def test_yml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
|
||||
|
||||
def test_toml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.toml") == "toml"
|
||||
|
||||
def test_parquet_extension(self):
|
||||
assert infer_format_from_uri("/data/table.parquet") == "parquet"
|
||||
|
||||
def test_xlsx_extension(self):
|
||||
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
|
||||
|
||||
def test_xls_extension_returns_xls_label(self):
|
||||
# Legacy .xls is mapped so callers can produce a helpful error.
|
||||
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert infer_format_from_uri("/data/FILE.JSON") == "json"
|
||||
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
|
||||
|
||||
def test_unicode_filename(self):
|
||||
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
|
||||
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
|
||||
|
||||
def test_unknown_extension(self):
|
||||
assert infer_format_from_uri("/home/user/readme.txt") is None
|
||||
|
||||
def test_no_extension(self):
|
||||
assert infer_format_from_uri("workspace://abc123") is None
|
||||
|
||||
# --- MIME-based ---
|
||||
|
||||
def test_mime_json(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
|
||||
|
||||
def test_mime_csv(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
|
||||
|
||||
def test_mime_tsv(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
|
||||
== "tsv"
|
||||
)
|
||||
|
||||
def test_mime_ndjson(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
|
||||
)
|
||||
|
||||
def test_mime_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
|
||||
|
||||
def test_mime_xlsx(self):
|
||||
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
assert infer_format_from_uri(uri) == "xlsx"
|
||||
|
||||
def test_mime_parquet(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
|
||||
== "parquet"
|
||||
)
|
||||
|
||||
def test_unknown_mime(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/plain") is None
|
||||
|
||||
def test_unknown_mime_falls_through_to_extension(self):
|
||||
# Unknown MIME (text/plain) should fall through to extension-based detection.
|
||||
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
|
||||
|
||||
# --- MIME takes precedence over extension ---
|
||||
|
||||
def test_mime_overrides_extension(self):
|
||||
# .txt extension but JSON MIME → json
|
||||
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJson:
|
||||
def test_array(self):
|
||||
result = parse_file_content("[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_object(self):
|
||||
result = parse_file_content('{"key": "value"}', "json")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_nested(self):
|
||||
content = json.dumps({"rows": [[1, 2], [3, 4]]})
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == {"rows": [[1, 2], [3, 4]]}
|
||||
|
||||
def test_scalar_string_stays_as_string(self):
|
||||
result = parse_file_content('"hello"', "json")
|
||||
assert result == '"hello"' # original content, not parsed
|
||||
|
||||
def test_scalar_number_stays_as_string(self):
|
||||
result = parse_file_content("42", "json")
|
||||
assert result == "42"
|
||||
|
||||
def test_scalar_boolean_stays_as_string(self):
|
||||
result = parse_file_content("true", "json")
|
||||
assert result == "true"
|
||||
|
||||
def test_null_stays_as_string(self):
|
||||
result = parse_file_content("null", "json")
|
||||
assert result == "null"
|
||||
|
||||
def test_invalid_json_fallback(self):
|
||||
content = "not json at all"
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == content
|
||||
|
||||
def test_empty_string_fallback(self):
|
||||
result = parse_file_content("", "json")
|
||||
assert result == ""
|
||||
|
||||
def test_bytes_input_decoded(self):
|
||||
result = parse_file_content(b"[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSONL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJsonl:
|
||||
def test_tabular_uniform_dicts_to_table_format(self):
|
||||
"""JSONL with uniform dict keys → table format (header + rows),
|
||||
consistent with CSV/TSV/Parquet/Excel output."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", "yellow"],
|
||||
["cherry", "red"],
|
||||
]
|
||||
|
||||
def test_tabular_single_key_dicts(self):
|
||||
"""JSONL with single-key uniform dicts → table format."""
|
||||
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2], [3]]
|
||||
|
||||
def test_tabular_blank_lines_skipped(self):
|
||||
content = '{"a": 1}\n\n{"a": 2}\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2]]
|
||||
|
||||
def test_heterogeneous_dicts_stay_as_list(self):
|
||||
"""JSONL with different keys across objects → list of dicts (no table)."""
|
||||
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
|
||||
|
||||
def test_partially_overlapping_keys_stay_as_list(self):
|
||||
"""JSONL dicts with partially overlapping keys → list of dicts."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
{"name": "apple", "color": "red"},
|
||||
{"name": "banana", "size": "medium"},
|
||||
]
|
||||
|
||||
def test_mixed_types_stay_as_list(self):
|
||||
"""JSONL with non-dict lines → list of parsed values (no table)."""
|
||||
content = '1\n"hello"\n[1,2]\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [1, "hello", [1, 2]]
|
||||
|
||||
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
|
||||
"""JSONL mixing dicts and non-dicts → list of parsed values."""
|
||||
content = '{"a": 1}\n42\n{"b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1}, 42, {"b": 2}]
|
||||
|
||||
def test_tabular_preserves_key_order(self):
|
||||
"""Table header should follow the key order of the first object."""
|
||||
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result[0] == ["z", "a"] # order from first object
|
||||
assert result[1] == [1, 2]
|
||||
assert result[2] == [3, 4]
|
||||
|
||||
def test_single_dict_stays_as_list(self):
|
||||
"""Single-line JSONL with one dict → [dict], NOT a table.
|
||||
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
|
||||
content = '{"a": 1, "b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1, "b": 2}]
|
||||
|
||||
def test_tabular_with_none_values(self):
|
||||
"""Uniform keys but some null values → table with None cells."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", None],
|
||||
]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "jsonl")
|
||||
assert result == ""
|
||||
|
||||
def test_all_blank_lines_fallback(self):
|
||||
result = parse_file_content("\n\n\n", "jsonl")
|
||||
assert result == "\n\n\n"
|
||||
|
||||
def test_invalid_line_fallback(self):
|
||||
content = '{"a": 1}\nnot json\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == content # fallback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseCsv:
|
||||
def test_basic(self):
|
||||
content = "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_quoted_fields(self):
|
||||
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result[1] == ["Alice", "Loves, commas"]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
# Only 1 column — not tabular enough.
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
def test_empty_rows_skipped(self):
|
||||
content = "A,B\n\n1,2\n\n3,4"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "csv")
|
||||
assert result == ""
|
||||
|
||||
def test_utf8_bom(self):
|
||||
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
|
||||
bom = "\ufeff"
|
||||
content = bom + "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
# The BOM may be part of the first header cell; ensure rows are still parsed.
|
||||
assert len(result) == 3
|
||||
assert result[1] == ["Alice", "90"]
|
||||
assert result[2] == ["Bob", "85"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTsv:
|
||||
def test_basic(self):
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — YAML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseYaml:
|
||||
def test_list(self):
|
||||
content = "- apple\n- banana\n- cherry"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == ["apple", "banana", "cherry"]
|
||||
|
||||
def test_dict(self):
|
||||
content = "name: Alice\nage: 30"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"name": "Alice", "age": 30}
|
||||
|
||||
def test_nested(self):
|
||||
content = "users:\n - name: Alice\n - name: Bob"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
|
||||
|
||||
def test_scalar_stays_as_string(self):
|
||||
result = parse_file_content("hello world", "yaml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_invalid_yaml_fallback(self):
|
||||
content = ":\n :\n invalid: - -"
|
||||
result = parse_file_content(content, "yaml")
|
||||
# Malformed YAML should fall back to the original string, not raise.
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TOML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseToml:
|
||||
def test_basic(self):
|
||||
content = '[server]\nhost = "localhost"\nport = 8080'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"server": {"host": "localhost", "port": 8080}}
|
||||
|
||||
def test_flat(self):
|
||||
content = 'name = "test"\ncount = 42'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"name": "test", "count": 42}
|
||||
|
||||
def test_empty_string_returns_empty_dict(self):
|
||||
result = parse_file_content("", "toml")
|
||||
assert result == {}
|
||||
|
||||
def test_invalid_toml_fallback(self):
|
||||
result = parse_file_content("not = [valid toml", "toml")
|
||||
assert result == "not = [valid toml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Parquet (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
|
||||
_has_pyarrow = True
|
||||
except ImportError:
|
||||
_has_pyarrow = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestParseParquet:
|
||||
@pytest.fixture
|
||||
def parquet_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, parquet_bytes: bytes):
|
||||
result = parse_file_content(parquet_bytes, "parquet")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
# Parquet is binary — string input can't be parsed.
|
||||
result = parse_file_content("not parquet", "parquet")
|
||||
assert result == "not parquet"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not parquet bytes", "parquet")
|
||||
assert result == b"not parquet bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "parquet")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in Parquet must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
# Row with NaN in float col → None
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[2][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]:
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Excel (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseExcel:
|
||||
@pytest.fixture
|
||||
def xlsx_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, xlsx_bytes: bytes):
|
||||
result = parse_file_content(xlsx_bytes, "xlsx")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
result = parse_file_content("not xlsx", "xlsx")
|
||||
assert result == "not xlsx"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "xlsx")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in float columns must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type]
|
||||
result = parse_file_content(buf.getvalue(), "xlsx")
|
||||
# Row with NaN in float col → None, not float('nan')
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[3][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]: # skip header
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — unknown format / fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallback:
|
||||
def test_unknown_format_returns_content(self):
|
||||
result = parse_file_content("hello world", "xml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_none_format_returns_content(self):
|
||||
# Shouldn't normally be called with unrecognised format, but must not crash.
|
||||
result = parse_file_content("hello", "unknown_format")
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BINARY_FORMATS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBinaryFormats:
|
||||
def test_parquet_is_binary(self):
|
||||
assert "parquet" in BINARY_FORMATS
|
||||
|
||||
def test_xlsx_is_binary(self):
|
||||
assert "xlsx" in BINARY_FORMATS
|
||||
|
||||
def test_text_formats_not_binary(self):
|
||||
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
|
||||
assert fmt not in BINARY_FORMATS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIME mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMimeMapping:
|
||||
def test_application_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSV sniffer fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCsvSnifferFallback:
|
||||
def test_tab_delimited_with_csv_format(self):
|
||||
"""Tab-delimited content parsed as csv should use sniffer fallback."""
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_sniffer_failure_returns_content(self):
|
||||
"""When sniffer fails, single-column falls back to raw content."""
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenpyxlInvalidFile fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenpyxlFallback:
|
||||
def test_invalid_xlsx_non_strict(self):
|
||||
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header-only CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeaderOnlyCsv:
|
||||
def test_header_only_csv_returns_header_row(self):
|
||||
"""CSV with only a header row (no data rows) should return [[header]]."""
|
||||
content = "Name,Score"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
def test_header_only_csv_with_trailing_newline(self):
|
||||
content = "Name,Score\n"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary format + line range (line range ignored for binary formats)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestBinaryFormatLineRange:
|
||||
def test_parquet_ignores_line_range(self):
|
||||
"""Binary formats should parse the full file regardless of line range.
|
||||
|
||||
Line ranges are meaningless for binary formats (parquet/xlsx) — the
|
||||
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
|
||||
should return the complete structured data.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
# parse_file_content itself doesn't take a line range — this tests
|
||||
# that the full content is parsed even though the bytes could have
|
||||
# been truncated upstream (it's not, by design).
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy .xls UX
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestXlsFallback:
|
||||
def test_xls_returns_helpful_error_string(self):
|
||||
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
|
||||
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
assert "not supported" in result.lower()
|
||||
|
||||
def test_xls_with_string_content(self):
|
||||
result = parse_file_content("some text", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
@@ -8,12 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import (
|
||||
is_media_file_ref,
|
||||
parse_data_uri,
|
||||
resolve_media_content,
|
||||
store_media_file,
|
||||
)
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
@@ -349,162 +344,3 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_media_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsMediaFileRef:
|
||||
def test_data_uri(self):
|
||||
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
|
||||
|
||||
def test_workspace_uri(self):
|
||||
assert is_media_file_ref("workspace://abc123") is True
|
||||
|
||||
def test_workspace_uri_with_mime(self):
|
||||
assert is_media_file_ref("workspace://abc123#image/png") is True
|
||||
|
||||
def test_http_url(self):
|
||||
assert is_media_file_ref("http://example.com/image.png") is True
|
||||
|
||||
def test_https_url(self):
|
||||
assert is_media_file_ref("https://example.com/image.png") is True
|
||||
|
||||
def test_plain_text(self):
|
||||
assert is_media_file_ref("print('hello')") is False
|
||||
|
||||
def test_local_path(self):
|
||||
assert is_media_file_ref("/tmp/file.txt") is False
|
||||
|
||||
def test_empty_string(self):
|
||||
assert is_media_file_ref("") is False
|
||||
|
||||
def test_filename(self):
|
||||
assert is_media_file_ref("image.png") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_data_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDataUri:
|
||||
def test_valid_png(self):
|
||||
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
|
||||
assert result is not None
|
||||
mime, payload = result
|
||||
assert mime == "image/png"
|
||||
assert payload == "iVBORw0KGg=="
|
||||
|
||||
def test_valid_text(self):
|
||||
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
|
||||
assert result is not None
|
||||
assert result[0] == "text/plain"
|
||||
assert result[1] == "SGVsbG8="
|
||||
|
||||
def test_mime_case_normalized(self):
|
||||
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
|
||||
assert result is not None
|
||||
assert result[0] == "image/png"
|
||||
|
||||
def test_not_data_uri(self):
|
||||
assert parse_data_uri("workspace://abc123") is None
|
||||
|
||||
def test_plain_text(self):
|
||||
assert parse_data_uri("hello world") is None
|
||||
|
||||
def test_missing_base64(self):
|
||||
assert parse_data_uri("data:image/png;utf-8,abc") is None
|
||||
|
||||
def test_empty_payload(self):
|
||||
result = parse_data_uri("data:image/png;base64,")
|
||||
assert result is not None
|
||||
assert result[1] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_media_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveMediaContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_passthrough(self):
|
||||
"""Plain text content (not a media ref) passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("print('hello')"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_passthrough(self):
|
||||
"""Empty string passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(""),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_ref_delegates_to_store(self):
|
||||
"""Media references are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_uri_delegates_to_store(self):
|
||||
"""Data URIs are also resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
data_uri = "data:image/png;base64,iVBORw0KGg=="
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType(data_uri)),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(data_uri),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == data_uri
|
||||
mock_store.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_https_url_delegates_to_store(self):
|
||||
"""HTTPS URLs are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import date
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
|
||||
|
||||
@@ -125,6 +126,22 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
nightly_copilot_callback_start_date: date = Field(
|
||||
default=date(2026, 2, 8),
|
||||
description="Users with sessions since this date are eligible for the one-off autopilot callback cohort.",
|
||||
)
|
||||
nightly_copilot_invite_cta_start_date: date = Field(
|
||||
default=date(2026, 3, 13),
|
||||
description="Invite CTA cohort does not run before this date.",
|
||||
)
|
||||
nightly_copilot_invite_cta_delay_hours: int = Field(
|
||||
default=48,
|
||||
description="Delay after invite creation before the invite CTA can run.",
|
||||
)
|
||||
nightly_copilot_callback_token_ttl_hours: int = Field(
|
||||
default=24 * 14,
|
||||
description="TTL for nightly copilot callback tokens.",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
default=False,
|
||||
description="If user credit system is enabled or not",
|
||||
|
||||
@@ -86,9 +86,13 @@ class TextFormatter:
|
||||
"i",
|
||||
"img",
|
||||
"li",
|
||||
"ol",
|
||||
"p",
|
||||
"span",
|
||||
"strong",
|
||||
"table",
|
||||
"td",
|
||||
"tr",
|
||||
"u",
|
||||
"ul",
|
||||
]
|
||||
@@ -98,6 +102,15 @@ class TextFormatter:
|
||||
"*": ["class", "style"],
|
||||
"a": ["href"],
|
||||
"img": ["src"],
|
||||
"table": [
|
||||
"align",
|
||||
"border",
|
||||
"cellpadding",
|
||||
"cellspacing",
|
||||
"role",
|
||||
"width",
|
||||
],
|
||||
"td": ["align", "bgcolor", "colspan", "height", "valign", "width"],
|
||||
}
|
||||
|
||||
def format_string(self, template_str: str, values=None, **kwargs) -> str:
|
||||
|
||||
9
autogpt_platform/backend/backend/util/url.py
Normal file
9
autogpt_platform/backend/backend/util/url.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_frontend_base_url() -> str:
|
||||
return (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
).rstrip("/")
|
||||
@@ -0,0 +1,67 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ChatSessionStartType" AS ENUM(
|
||||
'MANUAL',
|
||||
'AUTOPILOT_NIGHTLY',
|
||||
'AUTOPILOT_CALLBACK',
|
||||
'AUTOPILOT_INVITE_CTA'
|
||||
);
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "ChatSession"
|
||||
ADD COLUMN "startType" "ChatSessionStartType" NOT NULL DEFAULT 'MANUAL',
|
||||
ADD COLUMN "executionTag" TEXT,
|
||||
ADD COLUMN "sessionConfig" JSONB NOT NULL DEFAULT '{}',
|
||||
ADD COLUMN "completionReport" JSONB,
|
||||
ADD COLUMN "completionReportRepairCount" INTEGER NOT NULL DEFAULT 0,
|
||||
ADD COLUMN "completionReportRepairQueuedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "completedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "notificationEmailSentAt" TIMESTAMP(3),
|
||||
ADD COLUMN "notificationEmailSkippedAt" TIMESTAMP(3);
|
||||
|
||||
COMMENT ON COLUMN "ChatSession"."sessionConfig" IS 'Validated by backend.copilot.session_types.ChatSessionConfig';
|
||||
COMMENT ON COLUMN "ChatSession"."completionReport" IS 'Validated by backend.copilot.session_types.StoredCompletionReport';
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ChatSessionCallbackToken"(
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"sourceSessionId" TEXT,
|
||||
"callbackSessionMessage" TEXT NOT NULL,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"consumedAt" TIMESTAMP(3),
|
||||
"consumedSessionId" TEXT,
|
||||
CONSTRAINT "ChatSessionCallbackToken_pkey" PRIMARY KEY("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ChatSession_userId_executionTag_key"
|
||||
ON "ChatSession"("userId",
|
||||
"executionTag");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSession_userId_startType_updatedAt_idx"
|
||||
ON "ChatSession"("userId",
|
||||
"startType",
|
||||
"updatedAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSessionCallbackToken_userId_expiresAt_idx"
|
||||
ON "ChatSessionCallbackToken"("userId",
|
||||
"expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSessionCallbackToken_consumedSessionId_idx"
|
||||
ON "ChatSessionCallbackToken"("consumedSessionId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatSessionCallbackToken" ADD CONSTRAINT "ChatSessionCallbackToken_userId_fkey" FOREIGN KEY("userId") REFERENCES "User"("id")
|
||||
ON DELETE CASCADE
|
||||
ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatSessionCallbackToken" ADD CONSTRAINT "ChatSessionCallbackToken_sourceSessionId_fkey" FOREIGN KEY("sourceSessionId") REFERENCES "ChatSession"("id")
|
||||
ON DELETE
|
||||
CASCADE
|
||||
ON UPDATE CASCADE;
|
||||
89
autogpt_platform/backend/poetry.lock
generated
89
autogpt_platform/backend/poetry.lock
generated
@@ -1360,18 +1360,6 @@ files = [
|
||||
dnspython = ">=2.0.0"
|
||||
idna = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "et-xmlfile"
|
||||
version = "2.0.0"
|
||||
description = "An implementation of lxml.xmlfile for the standard library"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
|
||||
{file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exa-py"
|
||||
version = "1.16.1"
|
||||
@@ -4240,21 +4228,6 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "openpyxl"
|
||||
version = "3.1.5"
|
||||
description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
|
||||
{file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
et-xmlfile = "*"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.39.1"
|
||||
@@ -5457,66 +5430,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "23.0.1"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce"},
|
||||
{file = "pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.2"
|
||||
@@ -8969,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "86dab25684dd46e635a33bd33281a926e5626a874ecc048c34389fecf34a87d8"
|
||||
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
|
||||
|
||||
@@ -37,6 +37,7 @@ jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^3.14.1"
|
||||
markdown-it-py = "^3.0.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
@@ -92,8 +93,6 @@ gravitas-md2gdocs = "^0.1.0"
|
||||
posthog = "^7.6.0"
|
||||
fpdf2 = "^2.8.6"
|
||||
langsmith = "^0.7.7"
|
||||
openpyxl = "^3.1.5"
|
||||
pyarrow = "^23.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
|
||||
@@ -66,6 +66,7 @@ model User {
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
|
||||
ChatSessionCallbackTokens ChatSessionCallbackToken[]
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -87,6 +88,13 @@ enum TallyComputationStatus {
|
||||
FAILED
|
||||
}
|
||||
|
||||
enum ChatSessionStartType {
|
||||
MANUAL
|
||||
AUTOPILOT_NIGHTLY
|
||||
AUTOPILOT_CALLBACK
|
||||
AUTOPILOT_INVITE_CTA
|
||||
}
|
||||
|
||||
model InvitedUser {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -248,6 +256,15 @@ model ChatSession {
|
||||
// Session metadata
|
||||
title String?
|
||||
credentials Json @default("{}") // Map of provider -> credential metadata
|
||||
startType ChatSessionStartType @default(MANUAL)
|
||||
executionTag String?
|
||||
sessionConfig Json @default("{}") // ChatSessionConfig payload from backend.copilot.session_types.ChatSessionConfig
|
||||
completionReport Json? // StoredCompletionReport payload from backend.copilot.session_types.StoredCompletionReport
|
||||
completionReportRepairCount Int @default(0)
|
||||
completionReportRepairQueuedAt DateTime?
|
||||
completedAt DateTime?
|
||||
notificationEmailSentAt DateTime?
|
||||
notificationEmailSkippedAt DateTime?
|
||||
|
||||
// Rate limiting counters (stored as JSON maps)
|
||||
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
|
||||
@@ -258,8 +275,31 @@ model ChatSession {
|
||||
totalCompletionTokens Int @default(0)
|
||||
|
||||
Messages ChatMessage[]
|
||||
CallbackTokens ChatSessionCallbackToken[] @relation("ChatSessionCallbackSource")
|
||||
|
||||
@@index([userId, updatedAt])
|
||||
@@index([userId, startType, updatedAt])
|
||||
@@unique([userId, executionTag])
|
||||
}
|
||||
|
||||
model ChatSessionCallbackToken {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
sourceSessionId String?
|
||||
SourceSession ChatSession? @relation("ChatSessionCallbackSource", fields: [sourceSessionId], references: [id], onDelete: Cascade)
|
||||
|
||||
callbackSessionMessage String
|
||||
expiresAt DateTime
|
||||
consumedAt DateTime?
|
||||
consumedSessionId String?
|
||||
|
||||
@@index([userId, expiresAt])
|
||||
@@index([consumedSessionId])
|
||||
}
|
||||
|
||||
model ChatMessage {
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"use client";
|
||||
|
||||
import { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Card } from "@/components/atoms/Card/Card";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { CopilotUsersTable } from "../CopilotUsersTable/CopilotUsersTable";
|
||||
import { useAdminCopilotPage } from "../../useAdminCopilotPage";
|
||||
|
||||
function getStartTypeLabel(startType: ChatSessionStartType) {
|
||||
if (startType === ChatSessionStartType.AUTOPILOT_INVITE_CTA) {
|
||||
return "CTA";
|
||||
}
|
||||
|
||||
if (startType === ChatSessionStartType.AUTOPILOT_NIGHTLY) {
|
||||
return "Nightly";
|
||||
}
|
||||
|
||||
if (startType === ChatSessionStartType.AUTOPILOT_CALLBACK) {
|
||||
return "Callback";
|
||||
}
|
||||
|
||||
return startType;
|
||||
}
|
||||
|
||||
const triggerOptions = [
|
||||
{
|
||||
label: "Trigger CTA",
|
||||
description:
|
||||
"Runs the beta invite CTA flow even if the user would not normally qualify.",
|
||||
startType: ChatSessionStartType.AUTOPILOT_INVITE_CTA,
|
||||
variant: "primary" as const,
|
||||
},
|
||||
{
|
||||
label: "Trigger Nightly",
|
||||
description:
|
||||
"Runs the nightly proactive Autopilot flow immediately for the selected user.",
|
||||
startType: ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
variant: "outline" as const,
|
||||
},
|
||||
{
|
||||
label: "Trigger Callback",
|
||||
description:
|
||||
"Runs the callback re-engagement flow without checking the normal callback cohort.",
|
||||
startType: ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
variant: "secondary" as const,
|
||||
},
|
||||
];
|
||||
|
||||
export function AdminCopilotPage() {
|
||||
const {
|
||||
search,
|
||||
selectedUser,
|
||||
pendingTriggerType,
|
||||
lastTriggeredSession,
|
||||
lastEmailSweepResult,
|
||||
searchedUsers,
|
||||
searchErrorMessage,
|
||||
isSearchingUsers,
|
||||
isRefreshingUsers,
|
||||
isTriggeringSession,
|
||||
isSendingPendingEmails,
|
||||
hasSearch,
|
||||
setSearch,
|
||||
handleSelectUser,
|
||||
handleSendPendingEmails,
|
||||
handleTriggerSession,
|
||||
} = useAdminCopilotPage();
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h1 className="text-3xl font-bold text-zinc-900">Copilot</h1>
|
||||
<p className="max-w-3xl text-sm text-zinc-600">
|
||||
Manually create CTA, Nightly, or Callback Copilot sessions for a
|
||||
specific user. These controls bypass the normal eligibility checks so
|
||||
you can test each flow directly.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 xl:grid-cols-[minmax(0,1.35fr),24rem]">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<Input
|
||||
id="copilot-user-search"
|
||||
label="Search users"
|
||||
hint="Results update as you type"
|
||||
placeholder="Search by email, name, or user ID"
|
||||
value={search}
|
||||
onChange={(event) => setSearch(event.target.value)}
|
||||
/>
|
||||
{searchErrorMessage ? (
|
||||
<p className="-mt-2 text-sm text-red-500">{searchErrorMessage}</p>
|
||||
) : null}
|
||||
<CopilotUsersTable
|
||||
users={searchedUsers}
|
||||
isLoading={isSearchingUsers}
|
||||
isRefreshing={isRefreshingUsers}
|
||||
hasSearch={hasSearch}
|
||||
selectedUserId={selectedUser?.id ?? null}
|
||||
onSelectUser={handleSelectUser}
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<div className="flex flex-col gap-6">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Selected user
|
||||
</h2>
|
||||
{selectedUser ? <Badge variant="info">Ready</Badge> : null}
|
||||
</div>
|
||||
|
||||
{selectedUser ? (
|
||||
<div className="flex flex-col gap-3 text-sm text-zinc-600">
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Email
|
||||
</span>
|
||||
<p className="mt-1 font-medium text-zinc-900">
|
||||
{selectedUser.email}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Name
|
||||
</span>
|
||||
<p className="mt-1">
|
||||
{selectedUser.name || "No display name"}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Timezone
|
||||
</span>
|
||||
<p className="mt-1">{selectedUser.timezone}</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
User ID
|
||||
</span>
|
||||
<p className="mt-1 break-all font-mono text-xs text-zinc-500">
|
||||
{selectedUser.id}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-sm text-zinc-500">
|
||||
Select a user from the results table to enable manual Copilot
|
||||
triggers.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Trigger flows
|
||||
</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Each action creates a new session immediately for the selected
|
||||
user.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-3">
|
||||
{triggerOptions.map((option) => (
|
||||
<div
|
||||
key={option.startType}
|
||||
className="rounded-2xl border border-zinc-200 p-4"
|
||||
>
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="font-medium text-zinc-900">
|
||||
{option.label}
|
||||
</span>
|
||||
<p className="text-sm text-zinc-600">
|
||||
{option.description}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant={option.variant}
|
||||
disabled={!selectedUser || isTriggeringSession}
|
||||
loading={pendingTriggerType === option.startType}
|
||||
onClick={() => handleTriggerSession(option.startType)}
|
||||
>
|
||||
{option.label}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Email follow-up
|
||||
</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Run the pending Copilot completion-email sweep immediately for
|
||||
the selected user.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
disabled={!selectedUser || isSendingPendingEmails}
|
||||
loading={isSendingPendingEmails}
|
||||
onClick={handleSendPendingEmails}
|
||||
>
|
||||
Send pending emails
|
||||
</Button>
|
||||
{selectedUser && lastEmailSweepResult ? (
|
||||
<div className="rounded-2xl border border-zinc-200 p-4 text-sm text-zinc-600">
|
||||
<p className="font-medium text-zinc-900">
|
||||
Last sweep for {selectedUser.email}
|
||||
</p>
|
||||
<div className="mt-3 grid gap-3 sm:grid-cols-2">
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Candidates
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.candidate_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Processed
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.processed_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Sent
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.sent_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Skipped
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.skipped_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Repairs queued
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.repair_queued_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Running / failed
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.running_count} /{" "}
|
||||
{lastEmailSweepResult.failed_count}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{selectedUser && lastTriggeredSession ? (
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Latest session
|
||||
</h2>
|
||||
<Badge variant="success">
|
||||
{getStartTypeLabel(lastTriggeredSession.start_type)}
|
||||
</Badge>
|
||||
</div>
|
||||
<p className="text-sm text-zinc-600">
|
||||
A new Copilot session was created for {selectedUser.email}.
|
||||
</p>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Session ID
|
||||
</span>
|
||||
<p className="mt-1 break-all font-mono text-xs text-zinc-500">
|
||||
{lastTriggeredSession.session_id}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
as="NextLink"
|
||||
href={`/copilot?sessionId=${lastTriggeredSession.session_id}&showAutopilot=1`}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
>
|
||||
Open session
|
||||
</Button>
|
||||
</div>
|
||||
</Card>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
"use client";
|
||||
|
||||
import type { AdminCopilotUserSummary } from "@/app/api/__generated__/models/adminCopilotUserSummary";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
interface Props {
|
||||
users: AdminCopilotUserSummary[];
|
||||
isLoading: boolean;
|
||||
isRefreshing: boolean;
|
||||
hasSearch: boolean;
|
||||
selectedUserId: string | null;
|
||||
onSelectUser: (user: AdminCopilotUserSummary) => void;
|
||||
}
|
||||
|
||||
function formatDate(value: Date) {
|
||||
return value.toLocaleString();
|
||||
}
|
||||
|
||||
export function CopilotUsersTable({
|
||||
users,
|
||||
isLoading,
|
||||
isRefreshing,
|
||||
hasSearch,
|
||||
selectedUserId,
|
||||
onSelectUser,
|
||||
}: Props) {
|
||||
let emptyMessage = "Search by email, name, or user ID to find a user.";
|
||||
if (hasSearch && isLoading) {
|
||||
emptyMessage = "Searching users...";
|
||||
} else if (hasSearch) {
|
||||
emptyMessage = "No matching users found.";
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">User results</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Select an existing user, then run an Autopilot flow manually.
|
||||
</p>
|
||||
</div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
{isRefreshing
|
||||
? "Refreshing"
|
||||
: `${users.length} result${users.length === 1 ? "" : "s"}`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="overflow-hidden rounded-2xl border border-zinc-200">
|
||||
<Table>
|
||||
<TableHeader className="bg-zinc-50">
|
||||
<TableRow>
|
||||
<TableHead>User</TableHead>
|
||||
<TableHead>Timezone</TableHead>
|
||||
<TableHead>Updated</TableHead>
|
||||
<TableHead className="text-right">Action</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{users.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={4}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
{emptyMessage}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
users.map((user) => (
|
||||
<TableRow key={user.id} className="align-top">
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="font-medium text-zinc-900">
|
||||
{user.email}
|
||||
</span>
|
||||
<span className="text-sm text-zinc-600">
|
||||
{user.name || "No display name"}
|
||||
</span>
|
||||
<span className="font-mono text-xs text-zinc-400">
|
||||
{user.id}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-600">
|
||||
{user.timezone}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-600">
|
||||
{formatDate(user.updated_at)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
variant={
|
||||
user.id === selectedUserId ? "secondary" : "outline"
|
||||
}
|
||||
size="small"
|
||||
onClick={() => onSelectUser(user)}
|
||||
>
|
||||
{user.id === selectedUserId ? "Selected" : "Select"}
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { AdminCopilotPage } from "./components/AdminCopilotPage/AdminCopilotPage";
|
||||
|
||||
function AdminCopilot() {
|
||||
return <AdminCopilotPage />;
|
||||
}
|
||||
|
||||
export default async function AdminCopilotRoute() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminCopilot = await withAdminAccess(AdminCopilot);
|
||||
return <ProtectedAdminCopilot />;
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
"use client";
|
||||
|
||||
import type { AdminCopilotUserSummary } from "@/app/api/__generated__/models/adminCopilotUserSummary";
|
||||
import { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import type { SendCopilotEmailsResponse } from "@/app/api/__generated__/models/sendCopilotEmailsResponse";
|
||||
import type { TriggerCopilotSessionResponse } from "@/app/api/__generated__/models/triggerCopilotSessionResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { customMutator } from "@/app/api/mutators/custom-mutator";
|
||||
import {
|
||||
useGetV2SearchCopilotUsers,
|
||||
usePostV2TriggerCopilotSession,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { ApiError } from "@/lib/autogpt-server-api/helpers";
|
||||
import { useMutation } from "@tanstack/react-query";
|
||||
import { useDeferredValue, useState } from "react";
|
||||
|
||||
function getErrorMessage(error: unknown) {
|
||||
if (error instanceof ApiError) {
|
||||
if (
|
||||
typeof error.response === "object" &&
|
||||
error.response !== null &&
|
||||
"detail" in error.response &&
|
||||
typeof error.response.detail === "string"
|
||||
) {
|
||||
return error.response.detail;
|
||||
}
|
||||
|
||||
return error.message;
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return "Something went wrong";
|
||||
}
|
||||
|
||||
export function useAdminCopilotPage() {
|
||||
const { toast } = useToast();
|
||||
const [search, setSearch] = useState("");
|
||||
const [selectedUser, setSelectedUser] =
|
||||
useState<AdminCopilotUserSummary | null>(null);
|
||||
const [pendingTriggerType, setPendingTriggerType] =
|
||||
useState<ChatSessionStartType | null>(null);
|
||||
const [lastTriggeredSession, setLastTriggeredSession] =
|
||||
useState<TriggerCopilotSessionResponse | null>(null);
|
||||
const [lastEmailSweepResult, setLastEmailSweepResult] =
|
||||
useState<SendCopilotEmailsResponse | null>(null);
|
||||
|
||||
const deferredSearch = useDeferredValue(search);
|
||||
const normalizedSearch = deferredSearch.trim();
|
||||
|
||||
const searchUsersQuery = useGetV2SearchCopilotUsers(
|
||||
normalizedSearch ? { search: normalizedSearch, limit: 20 } : undefined,
|
||||
{
|
||||
query: {
|
||||
enabled: normalizedSearch.length > 0,
|
||||
select: okData,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const triggerCopilotSessionMutation = usePostV2TriggerCopilotSession({
|
||||
mutation: {
|
||||
onSuccess: (response) => {
|
||||
setPendingTriggerType(null);
|
||||
const session = okData(response) ?? null;
|
||||
setLastTriggeredSession(session);
|
||||
toast({
|
||||
title: "Copilot session created",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingTriggerType(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const sendPendingCopilotEmailsMutation = useMutation({
|
||||
mutationKey: ["sendPendingCopilotEmails"],
|
||||
mutationFn: async (userId: string) =>
|
||||
customMutator<{
|
||||
data: SendCopilotEmailsResponse;
|
||||
status: number;
|
||||
headers: Headers;
|
||||
}>("/api/users/admin/copilot/send-emails", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ user_id: userId }),
|
||||
}),
|
||||
onSuccess: (response) => {
|
||||
const result = okData(response) ?? null;
|
||||
setLastEmailSweepResult(result);
|
||||
if (!result) {
|
||||
toast({
|
||||
title: "Email sweep completed",
|
||||
variant: "default",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
toast({
|
||||
title:
|
||||
result.sent_count > 0
|
||||
? `Sent ${result.sent_count} Copilot email${result.sent_count === 1 ? "" : "s"}`
|
||||
: "Email sweep completed",
|
||||
description: [
|
||||
`${result.candidate_count} candidate${result.candidate_count === 1 ? "" : "s"}`,
|
||||
`${result.sent_count} sent`,
|
||||
`${result.skipped_count} skipped`,
|
||||
`${result.repair_queued_count} repairs queued`,
|
||||
`${result.running_count} still running`,
|
||||
`${result.failed_count} failed`,
|
||||
].join(" • "),
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error: unknown) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
function handleSelectUser(user: AdminCopilotUserSummary) {
|
||||
setSelectedUser(user);
|
||||
setLastTriggeredSession(null);
|
||||
setLastEmailSweepResult(null);
|
||||
}
|
||||
|
||||
function handleTriggerSession(startType: ChatSessionStartType) {
|
||||
if (!selectedUser) {
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingTriggerType(startType);
|
||||
setLastTriggeredSession(null);
|
||||
triggerCopilotSessionMutation.mutate({
|
||||
data: {
|
||||
user_id: selectedUser.id,
|
||||
start_type: startType,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleSendPendingEmails() {
|
||||
if (!selectedUser) {
|
||||
return;
|
||||
}
|
||||
|
||||
setLastEmailSweepResult(null);
|
||||
sendPendingCopilotEmailsMutation.mutate(selectedUser.id);
|
||||
}
|
||||
|
||||
return {
|
||||
search,
|
||||
selectedUser,
|
||||
pendingTriggerType,
|
||||
lastTriggeredSession,
|
||||
lastEmailSweepResult,
|
||||
searchedUsers: searchUsersQuery.data?.users ?? [],
|
||||
searchErrorMessage: searchUsersQuery.error
|
||||
? getErrorMessage(searchUsersQuery.error)
|
||||
: null,
|
||||
isSearchingUsers: searchUsersQuery.isLoading,
|
||||
isRefreshingUsers:
|
||||
searchUsersQuery.isFetching && !searchUsersQuery.isLoading,
|
||||
isTriggeringSession: triggerCopilotSessionMutation.isPending,
|
||||
isSendingPendingEmails: sendPendingCopilotEmailsMutation.isPending,
|
||||
hasSearch: normalizedSearch.length > 0,
|
||||
setSearch,
|
||||
handleSelectUser,
|
||||
handleTriggerSession,
|
||||
handleSendPendingEmails,
|
||||
};
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
MagnifyingGlassIcon,
|
||||
FileTextIcon,
|
||||
SlidersHorizontalIcon,
|
||||
LightningIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
const sidebarLinkGroups = [
|
||||
@@ -33,6 +34,11 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/impersonation",
|
||||
icon: <MagnifyingGlassIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Copilot",
|
||||
href: "/admin/copilot",
|
||||
icon: <LightningIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
|
||||
@@ -3,7 +3,6 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { ExclamationMarkIcon } from "@phosphor-icons/react";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useState } from "react";
|
||||
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
|
||||
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -130,8 +129,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-search_docs":
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-connect_integration":
|
||||
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -23,25 +23,22 @@ import {
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
DotsThree,
|
||||
PlusCircleIcon,
|
||||
PlusIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { getSessionListParams } from "../../helpers";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { SessionListItem } from "../SessionListItem/SessionListItem";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
const isCollapsed = state === "collapsed";
|
||||
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
|
||||
const listSessionsParams = getSessionListParams();
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
@@ -52,7 +49,9 @@ export function ChatSidebar() {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { data: sessionsResponse, isLoading: isLoadingSessions } =
|
||||
useGetV2ListSessions({ limit: 50 }, { query: { refetchInterval: 10_000 } });
|
||||
useGetV2ListSessions(listSessionsParams, {
|
||||
query: { refetchInterval: 10_000 },
|
||||
});
|
||||
|
||||
const { mutate: deleteSession, isPending: isDeleting } =
|
||||
useDeleteV2DeleteSession({
|
||||
@@ -180,31 +179,6 @@ export function ChatSidebar() {
|
||||
}
|
||||
}
|
||||
|
||||
function formatDate(dateString: string) {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||
|
||||
if (diffDays === 0) return "Today";
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
const day = date.getDate();
|
||||
const ordinal =
|
||||
day % 10 === 1 && day !== 11
|
||||
? "st"
|
||||
: day % 10 === 2 && day !== 12
|
||||
? "nd"
|
||||
: day % 10 === 3 && day !== 13
|
||||
? "rd"
|
||||
: "th";
|
||||
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||
const year = date.getFullYear();
|
||||
|
||||
return `${day}${ordinal} ${month} ${year}`;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Sidebar
|
||||
@@ -295,17 +269,17 @@ export function ChatSidebar() {
|
||||
No conversations yet
|
||||
</p>
|
||||
) : (
|
||||
sessions.map((session) => (
|
||||
<div
|
||||
key={session.id}
|
||||
className={cn(
|
||||
"group relative w-full rounded-lg transition-colors",
|
||||
session.id === sessionId
|
||||
? "bg-zinc-100"
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
{editingSessionId === session.id ? (
|
||||
sessions.map((session) =>
|
||||
editingSessionId === session.id ? (
|
||||
<div
|
||||
key={session.id}
|
||||
className={cn(
|
||||
"group relative w-full rounded-lg transition-colors",
|
||||
session.id === sessionId
|
||||
? "bg-zinc-100"
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<div className="px-3 py-2.5">
|
||||
<input
|
||||
ref={renameInputRef}
|
||||
@@ -331,87 +305,49 @@ export function ChatSidebar() {
|
||||
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<SessionListItem
|
||||
key={session.id}
|
||||
session={session}
|
||||
currentSessionId={sessionId}
|
||||
isCompleted={completedSessionIDs.has(session.id)}
|
||||
onSelect={handleSelectSession}
|
||||
variant="sidebar"
|
||||
actionSlot={
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
{session.is_processing &&
|
||||
session.id !== sessionId &&
|
||||
!completedSessionIDs.has(session.id) && (
|
||||
<PulseLoader size={16} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== sessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
{editingSessionId !== session.id && (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)}
|
||||
</div>
|
||||
))
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
}
|
||||
/>
|
||||
),
|
||||
)
|
||||
)}
|
||||
</motion.div>
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
PlusIcon,
|
||||
SpeakerHigh,
|
||||
SpeakerSlash,
|
||||
@@ -12,8 +9,9 @@ import {
|
||||
X,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Drawer } from "vaul";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { SessionListItem } from "../SessionListItem/SessionListItem";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
@@ -26,31 +24,6 @@ interface Props {
|
||||
onOpenChange: (open: boolean) => void;
|
||||
}
|
||||
|
||||
function formatDate(dateString: string) {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||
|
||||
if (diffDays === 0) return "Today";
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
const day = date.getDate();
|
||||
const ordinal =
|
||||
day % 10 === 1 && day !== 11
|
||||
? "st"
|
||||
: day % 10 === 2 && day !== 12
|
||||
? "nd"
|
||||
: day % 10 === 3 && day !== 13
|
||||
? "rd"
|
||||
: "th";
|
||||
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||
const year = date.getFullYear();
|
||||
|
||||
return `${day}${ordinal} ${month} ${year}`;
|
||||
}
|
||||
|
||||
export function MobileDrawer({
|
||||
isOpen,
|
||||
sessions,
|
||||
@@ -134,52 +107,19 @@ export function MobileDrawer({
|
||||
</p>
|
||||
) : (
|
||||
sessions.map((session) => (
|
||||
<button
|
||||
<SessionListItem
|
||||
key={session.id}
|
||||
onClick={() => {
|
||||
onSelectSession(session.id);
|
||||
if (completedSessionIDs.has(session.id)) {
|
||||
clearCompletedSession(session.id);
|
||||
session={session}
|
||||
currentSessionId={currentSessionId}
|
||||
isCompleted={completedSessionIDs.has(session.id)}
|
||||
variant="drawer"
|
||||
onSelect={(selectedSessionId) => {
|
||||
onSelectSession(selectedSessionId);
|
||||
if (completedSessionIDs.has(selectedSessionId)) {
|
||||
clearCompletedSession(selectedSessionId);
|
||||
}
|
||||
}}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
session.id === currentSessionId
|
||||
? "bg-zinc-100"
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="flex min-w-0 max-w-full items-center gap-1.5">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === currentSessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</Text>
|
||||
{session.is_processing &&
|
||||
!completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<PulseLoader size={8} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
"use client";
|
||||
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CheckCircle } from "@phosphor-icons/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import type { ReactNode } from "react";
|
||||
import {
|
||||
formatSessionDate,
|
||||
getSessionStartTypeLabel,
|
||||
isNonManualSessionStartType,
|
||||
} from "../../helpers";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
interface Props {
|
||||
actionSlot?: ReactNode;
|
||||
currentSessionId: string | null;
|
||||
isCompleted: boolean;
|
||||
onSelect: (sessionId: string) => void;
|
||||
session: SessionSummaryResponse;
|
||||
variant?: "sidebar" | "drawer";
|
||||
}
|
||||
|
||||
export function SessionListItem({
|
||||
actionSlot,
|
||||
currentSessionId,
|
||||
isCompleted,
|
||||
onSelect,
|
||||
session,
|
||||
variant = "sidebar",
|
||||
}: Props) {
|
||||
const isActive = session.id === currentSessionId;
|
||||
const showProcessing = session.is_processing && !isCompleted && !isActive;
|
||||
const showCompleted = isCompleted && !isActive;
|
||||
const startTypeLabel = isNonManualSessionStartType(session.start_type)
|
||||
? getSessionStartTypeLabel(session.start_type)
|
||||
: null;
|
||||
|
||||
if (variant === "drawer") {
|
||||
return (
|
||||
<button
|
||||
onClick={() => onSelect(session.id)}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="flex min-w-0 max-w-full items-center gap-1.5">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</Text>
|
||||
{showProcessing ? (
|
||||
<PulseLoader size={8} className="shrink-0" />
|
||||
) : null}
|
||||
{showCompleted ? (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
{startTypeLabel ? (
|
||||
<div className="mt-1">
|
||||
<Badge variant="info" size="small">
|
||||
{startTypeLabel}
|
||||
</Badge>
|
||||
</div>
|
||||
) : null}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatSessionDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group relative w-full rounded-lg transition-colors",
|
||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<button
|
||||
onClick={() => onSelect(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
{startTypeLabel ? (
|
||||
<div className="mt-1">
|
||||
<Badge variant="info" size="small">
|
||||
{startTypeLabel}
|
||||
</Badge>
|
||||
</div>
|
||||
) : null}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatSessionDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
{showProcessing ? (
|
||||
<PulseLoader size={16} className="shrink-0" />
|
||||
) : null}
|
||||
{showCompleted ? (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
</button>
|
||||
{actionSlot ? (
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2">
|
||||
{actionSlot}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,65 @@
|
||||
import type { GetV2ListSessionsParams } from "@/app/api/__generated__/models/getV2ListSessionsParams";
|
||||
import {
|
||||
ChatSessionStartType,
|
||||
type ChatSessionStartType as ChatSessionStartTypeValue,
|
||||
} from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import type { UIMessage } from "ai";
|
||||
|
||||
export const COPILOT_SESSION_LIST_LIMIT = 50;
|
||||
|
||||
export function getSessionListParams(): GetV2ListSessionsParams {
|
||||
return {
|
||||
limit: COPILOT_SESSION_LIST_LIMIT,
|
||||
with_auto: true,
|
||||
};
|
||||
}
|
||||
|
||||
export function isNonManualSessionStartType(
|
||||
startType: ChatSessionStartTypeValue | null | undefined,
|
||||
): boolean {
|
||||
return startType != null && startType !== ChatSessionStartType.MANUAL;
|
||||
}
|
||||
|
||||
export function getSessionStartTypeLabel(
|
||||
startType: ChatSessionStartTypeValue,
|
||||
): string | null {
|
||||
switch (startType) {
|
||||
case ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return "Nightly";
|
||||
case ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return "Callback";
|
||||
case ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return "Invite CTA";
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function formatSessionDate(dateString: string): string {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||
|
||||
if (diffDays === 0) return "Today";
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
const day = date.getDate();
|
||||
const ordinal =
|
||||
day % 10 === 1 && day !== 11
|
||||
? "st"
|
||||
: day % 10 === 2 && day !== 12
|
||||
? "nd"
|
||||
: day % 10 === 3 && day !== 13
|
||||
? "rd"
|
||||
: "th";
|
||||
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||
const year = date.getFullYear();
|
||||
|
||||
return `${day}${ordinal} ${month} ${year}`;
|
||||
}
|
||||
|
||||
/** Mark any in-progress tool parts as completed/errored so spinners stop. */
|
||||
export function resolveInProgressTools(
|
||||
messages: UIMessage[],
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useState } from "react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
|
||||
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
|
||||
|
||||
type Props = {
|
||||
part: ToolUIPart;
|
||||
};
|
||||
|
||||
function parseJson(raw: unknown): unknown {
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
return JSON.parse(raw);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
|
||||
return parsed as SetupRequirementsResponse;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseError(raw: unknown): string | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "message" in parsed) {
|
||||
return String((parsed as { message: unknown }).message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ConnectIntegrationTool({ part }: Props) {
|
||||
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
|
||||
const [isDismissed, setIsDismissed] = useState(false);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output =
|
||||
part.state === "output-available"
|
||||
? parseOutput((part as { output?: unknown }).output)
|
||||
: null;
|
||||
|
||||
const errorMessage = isError
|
||||
? (parseError((part as { output?: unknown }).output) ??
|
||||
"Failed to connect integration")
|
||||
: null;
|
||||
|
||||
const rawProvider =
|
||||
(part as { input?: { provider?: string } }).input?.provider ?? "";
|
||||
const providerName =
|
||||
output?.setup_info?.agent_name ??
|
||||
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
|
||||
// prevent runaway text in the DOM.
|
||||
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
|
||||
|
||||
const label = isStreaming
|
||||
? `Connecting ${providerName}…`
|
||||
: isError
|
||||
? `Failed to connect ${providerName}`
|
||||
: output
|
||||
? `Connect ${output.setup_info?.agent_name ?? providerName}`
|
||||
: `Connect ${providerName}`;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<MorphingTextAnimation
|
||||
text={label}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isError && errorMessage && (
|
||||
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
|
||||
)}
|
||||
|
||||
{output && (
|
||||
<div className="mt-2">
|
||||
{isDismissed ? (
|
||||
<ContentMessage>Connected. Continuing…</ContentMessage>
|
||||
) : (
|
||||
<SetupRequirementsCard
|
||||
output={output}
|
||||
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
|
||||
retryInstruction="I've connected my account. Please continue."
|
||||
onComplete={() => setIsDismissed(true)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,16 +23,12 @@ interface Props {
|
||||
/** Override the label shown above the credentials section.
|
||||
* Defaults to "Credentials". */
|
||||
credentialsLabel?: string;
|
||||
/** Called after Proceed is clicked so the parent can persist the dismissed state
|
||||
* across remounts (avoids re-enabling the Proceed button on remount). */
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export function SetupRequirementsCard({
|
||||
output,
|
||||
retryInstruction,
|
||||
credentialsLabel,
|
||||
onComplete,
|
||||
}: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
@@ -72,17 +68,13 @@ export function SetupRequirementsCard({
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
if (hasSent) {
|
||||
return <ContentMessage>Connected. Continuing…</ContentMessage>;
|
||||
}
|
||||
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
onComplete?.();
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import { usePostV2ConsumeCallbackTokenRoute } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useState } from "react";
|
||||
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
interface Props {
|
||||
isLoggedIn: boolean;
|
||||
onConsumed: (sessionId: string) => void;
|
||||
onClearAutopilot: () => void;
|
||||
}
|
||||
|
||||
export function useCallbackToken({
|
||||
isLoggedIn,
|
||||
onConsumed,
|
||||
onClearAutopilot,
|
||||
}: Props) {
|
||||
const queryClient = useQueryClient();
|
||||
const [callbackToken, setCallbackToken] = useQueryState(
|
||||
"callbackToken",
|
||||
parseAsString,
|
||||
);
|
||||
const [consumedTokens, setConsumedTokens] = useState<Set<string>>(
|
||||
() => new Set(),
|
||||
);
|
||||
const { mutateAsync: consumeCallbackToken, isPending } =
|
||||
usePostV2ConsumeCallbackTokenRoute();
|
||||
|
||||
const hasConsumedToken =
|
||||
callbackToken != null && consumedTokens.has(callbackToken);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isLoggedIn || !callbackToken || hasConsumedToken) {
|
||||
return;
|
||||
}
|
||||
|
||||
let isCancelled = false;
|
||||
const token = callbackToken;
|
||||
setConsumedTokens((current) => new Set(current).add(token));
|
||||
|
||||
void consumeCallbackToken({ data: { token } })
|
||||
.then((response) => {
|
||||
if (isCancelled) {
|
||||
return;
|
||||
}
|
||||
if (response.status !== 200 || !response.data?.session_id) {
|
||||
throw new Error("Failed to open callback session");
|
||||
}
|
||||
|
||||
onConsumed(response.data.session_id);
|
||||
onClearAutopilot();
|
||||
void setCallbackToken(null);
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
if (isCancelled) {
|
||||
return;
|
||||
}
|
||||
setConsumedTokens((current) => {
|
||||
const next = new Set(current);
|
||||
next.delete(token);
|
||||
return next;
|
||||
});
|
||||
void setCallbackToken(null);
|
||||
toast({
|
||||
title: "Unable to open callback session",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
});
|
||||
|
||||
return () => {
|
||||
isCancelled = true;
|
||||
};
|
||||
}, [
|
||||
callbackToken,
|
||||
consumeCallbackToken,
|
||||
hasConsumedToken,
|
||||
isLoggedIn,
|
||||
onClearAutopilot,
|
||||
onConsumed,
|
||||
queryClient,
|
||||
setCallbackToken,
|
||||
]);
|
||||
|
||||
return {
|
||||
isConsumingCallbackToken: isPending,
|
||||
};
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
useGetV2GetSession,
|
||||
usePostV2CreateSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
@@ -70,6 +71,14 @@ export function useChatSession() {
|
||||
);
|
||||
}, [sessionQuery.data, sessionId, hasActiveStream]);
|
||||
|
||||
const sessionStartType = useMemo<ChatSessionStartType | null>(() => {
|
||||
if (sessionQuery.data?.status !== 200) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return sessionQuery.data.data.start_type;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
|
||||
usePostV2CreateSession({
|
||||
mutation: {
|
||||
@@ -121,6 +130,7 @@ export function useChatSession() {
|
||||
return {
|
||||
sessionId,
|
||||
setSessionId,
|
||||
sessionStartType,
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
|
||||
@@ -2,34 +2,24 @@ import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
type getV2ListSessionsResponse,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { uploadFileDirect } from "@/lib/direct-upload";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { getSessionListParams } from "./helpers";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
import { useCallbackToken } from "./useCallbackToken";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useFileUpload } from "./useFileUpload";
|
||||
import { useCopilotNotifications } from "./useCopilotNotifications";
|
||||
import { useCopilotStream } from "./useCopilotStream";
|
||||
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
mime_type: string;
|
||||
}
|
||||
import { useTitlePolling } from "./useTitlePolling";
|
||||
|
||||
export function useCopilotPage() {
|
||||
const { isUserLoading, isLoggedIn } = useSupabase();
|
||||
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
const listSessionsParams = getSessionListParams();
|
||||
|
||||
const { sessionToDelete, setSessionToDelete, isDrawerOpen, setDrawerOpen } =
|
||||
useCopilotUIStore();
|
||||
@@ -39,7 +29,7 @@ export function useCopilotPage() {
|
||||
setSessionId,
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
isLoadingSession,
|
||||
isLoadingSession: isLoadingCurrentSession,
|
||||
isSessionError,
|
||||
createSession,
|
||||
isCreatingSession,
|
||||
@@ -93,198 +83,33 @@ export function useCopilotPage() {
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
const pendingFilesRef = useRef<File[]>([]);
|
||||
const { isConsumingCallbackToken } = useCallbackToken({
|
||||
isLoggedIn,
|
||||
onConsumed: setSessionId,
|
||||
onClearAutopilot() {},
|
||||
});
|
||||
|
||||
// --- Send pending message after session creation ---
|
||||
useEffect(() => {
|
||||
if (!sessionId || pendingMessage === null) return;
|
||||
const msg = pendingMessage;
|
||||
const files = pendingFilesRef.current;
|
||||
setPendingMessage(null);
|
||||
pendingFilesRef.current = [];
|
||||
|
||||
if (files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
void uploadFiles(files, sessionId)
|
||||
.then((uploaded) => {
|
||||
if (uploaded.length === 0) {
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: "Could not upload any files. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: msg,
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
})
|
||||
.finally(() => setIsUploadingFiles(false));
|
||||
} else {
|
||||
sendMessage({ text: msg });
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sid: string,
|
||||
): Promise<UploadedFile[]> {
|
||||
const results = await Promise.allSettled(
|
||||
files.map(async (file) => {
|
||||
try {
|
||||
const data = await uploadFileDirect(file, sid);
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} as UploadedFile;
|
||||
} catch (err) {
|
||||
console.error("File upload failed:", err);
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: file.name,
|
||||
variant: "destructive",
|
||||
});
|
||||
throw err;
|
||||
}
|
||||
}),
|
||||
);
|
||||
return results
|
||||
.filter(
|
||||
(r): r is PromiseFulfilledResult<UploadedFile> =>
|
||||
r.status === "fulfilled",
|
||||
)
|
||||
.map((r) => r.value);
|
||||
}
|
||||
|
||||
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
|
||||
return uploaded.map((f) => ({
|
||||
type: "file" as const,
|
||||
mediaType: f.mime_type,
|
||||
filename: f.name,
|
||||
url: `/api/proxy/api/workspace/files/${f.file_id}/download`,
|
||||
}));
|
||||
}
|
||||
|
||||
async function onSend(message: string, files?: File[]) {
|
||||
const trimmed = message.trim();
|
||||
if (!trimmed && (!files || files.length === 0)) return;
|
||||
|
||||
// Client-side file limits
|
||||
if (files && files.length > 0) {
|
||||
const MAX_FILES = 10;
|
||||
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024; // 100 MB
|
||||
|
||||
if (files.length > MAX_FILES) {
|
||||
toast({
|
||||
title: "Too many files",
|
||||
description: `You can attach up to ${MAX_FILES} files at once.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const oversized = files.filter((f) => f.size > MAX_FILE_SIZE_BYTES);
|
||||
if (oversized.length > 0) {
|
||||
toast({
|
||||
title: "File too large",
|
||||
description: `${oversized[0].name} exceeds the 100 MB limit.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
if (files && files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
const uploaded = await uploadFiles(files, sessionId);
|
||||
if (uploaded.length === 0) {
|
||||
// All uploads failed — abort send so chips revert to editable
|
||||
throw new Error("All file uploads failed");
|
||||
}
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: trimmed || "",
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
} finally {
|
||||
setIsUploadingFiles(false);
|
||||
}
|
||||
} else {
|
||||
sendMessage({ text: trimmed });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingMessage(trimmed || "");
|
||||
if (files && files.length > 0) {
|
||||
pendingFilesRef.current = files;
|
||||
}
|
||||
await createSession();
|
||||
}
|
||||
const { isUploadingFiles, onSend } = useFileUpload({
|
||||
createSession,
|
||||
isUserStoppingRef,
|
||||
sendMessage,
|
||||
sessionId,
|
||||
});
|
||||
|
||||
// --- Session list (for mobile drawer & sidebar) ---
|
||||
const { data: sessionsResponse, isLoading: isLoadingSessions } =
|
||||
useGetV2ListSessions(
|
||||
{ limit: 50 },
|
||||
{ query: { enabled: !isUserLoading && isLoggedIn } },
|
||||
);
|
||||
useGetV2ListSessions(listSessionsParams, {
|
||||
query: { enabled: !isUserLoading && isLoggedIn },
|
||||
});
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
// Start title polling when stream ends cleanly — sidebar title animates in
|
||||
const titlePollRef = useRef<ReturnType<typeof setInterval>>();
|
||||
const prevStatusRef = useRef(status);
|
||||
|
||||
useEffect(() => {
|
||||
const prev = prevStatusRef.current;
|
||||
prevStatusRef.current = status;
|
||||
|
||||
const wasActive = prev === "streaming" || prev === "submitted";
|
||||
const isNowReady = status === "ready";
|
||||
|
||||
if (!wasActive || !isNowReady || !sessionId || isReconnecting) return;
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
});
|
||||
const sid = sessionId;
|
||||
let attempts = 0;
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = setInterval(() => {
|
||||
const data = queryClient.getQueryData<getV2ListSessionsResponse>(
|
||||
getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
);
|
||||
const hasTitle =
|
||||
data?.status === 200 &&
|
||||
data.data.sessions.some((s) => s.id === sid && s.title);
|
||||
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = undefined;
|
||||
return;
|
||||
}
|
||||
attempts += 1;
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
});
|
||||
}, TITLE_POLL_INTERVAL_MS);
|
||||
}, [status, sessionId, isReconnecting, queryClient]);
|
||||
|
||||
// Clean up polling on session change or unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = undefined;
|
||||
};
|
||||
}, [sessionId]);
|
||||
useTitlePolling({
|
||||
isReconnecting,
|
||||
sessionId,
|
||||
status,
|
||||
});
|
||||
|
||||
// --- Mobile drawer handlers ---
|
||||
function handleOpenDrawer() {
|
||||
@@ -334,7 +159,7 @@ export function useCopilotPage() {
|
||||
error,
|
||||
stop,
|
||||
isReconnecting,
|
||||
isLoadingSession,
|
||||
isLoadingSession: isLoadingCurrentSession || isConsumingCallbackToken,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUploadingFiles,
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
import { uploadFileDirect } from "@/lib/direct-upload";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useEffect, useRef, useState, type MutableRefObject } from "react";
|
||||
|
||||
const MAX_FILES = 10;
|
||||
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024;
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
mime_type: string;
|
||||
}
|
||||
|
||||
interface SendMessageInput {
|
||||
text: string;
|
||||
files?: FileUIPart[];
|
||||
}
|
||||
|
||||
interface Props {
|
||||
createSession: () => Promise<unknown>;
|
||||
isUserStoppingRef: MutableRefObject<boolean>;
|
||||
sendMessage: (input: SendMessageInput) => void;
|
||||
sessionId: string | null;
|
||||
}
|
||||
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sessionId: string,
|
||||
): Promise<UploadedFile[]> {
|
||||
const results = await Promise.allSettled(
|
||||
files.map(async (file) => {
|
||||
try {
|
||||
const data = await uploadFileDirect(file, sessionId);
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} satisfies UploadedFile;
|
||||
} catch (error) {
|
||||
console.error("File upload failed:", error);
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: file.name,
|
||||
variant: "destructive",
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
return results
|
||||
.filter(
|
||||
(result): result is PromiseFulfilledResult<UploadedFile> =>
|
||||
result.status === "fulfilled",
|
||||
)
|
||||
.map((result) => result.value);
|
||||
}
|
||||
|
||||
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
|
||||
return uploaded.map((file) => ({
|
||||
type: "file" as const,
|
||||
mediaType: file.mime_type,
|
||||
filename: file.name,
|
||||
url: `/api/proxy/api/workspace/files/${file.file_id}/download`,
|
||||
}));
|
||||
}
|
||||
|
||||
export function useFileUpload({
|
||||
createSession,
|
||||
isUserStoppingRef,
|
||||
sendMessage,
|
||||
sessionId,
|
||||
}: Props) {
|
||||
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const pendingFilesRef = useRef<File[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!sessionId || pendingMessage === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
const message = pendingMessage;
|
||||
const files = pendingFilesRef.current;
|
||||
setPendingMessage(null);
|
||||
pendingFilesRef.current = [];
|
||||
|
||||
if (files.length === 0) {
|
||||
sendMessage({ text: message });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsUploadingFiles(true);
|
||||
void uploadFiles(files, sessionId)
|
||||
.then((uploaded) => {
|
||||
if (uploaded.length === 0) {
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: "Could not upload any files. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: message,
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
})
|
||||
.finally(() => setIsUploadingFiles(false));
|
||||
}, [pendingMessage, sendMessage, sessionId]);
|
||||
|
||||
async function onSend(message: string, files?: File[]) {
|
||||
const trimmed = message.trim();
|
||||
if (!trimmed && (!files || files.length === 0)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (files && files.length > 0) {
|
||||
if (files.length > MAX_FILES) {
|
||||
toast({
|
||||
title: "Too many files",
|
||||
description: `You can attach up to ${MAX_FILES} files at once.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const oversized = files.filter((file) => file.size > MAX_FILE_SIZE_BYTES);
|
||||
if (oversized.length > 0) {
|
||||
toast({
|
||||
title: "File too large",
|
||||
description: `${oversized[0].name} exceeds the 100 MB limit.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
if (!files || files.length === 0) {
|
||||
sendMessage({ text: trimmed });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
const uploaded = await uploadFiles(files, sessionId);
|
||||
if (uploaded.length === 0) {
|
||||
throw new Error("All file uploads failed");
|
||||
}
|
||||
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: trimmed || "",
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
} finally {
|
||||
setIsUploadingFiles(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingMessage(trimmed || "");
|
||||
pendingFilesRef.current = files ?? [];
|
||||
await createSession();
|
||||
}
|
||||
|
||||
return {
|
||||
isUploadingFiles,
|
||||
onSend,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
type getV2ListSessionsResponse,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { getSessionListParams } from "./helpers";
|
||||
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
interface Props {
|
||||
isReconnecting: boolean;
|
||||
sessionId: string | null;
|
||||
status: string;
|
||||
}
|
||||
|
||||
export function useTitlePolling({ isReconnecting, sessionId, status }: Props) {
|
||||
const queryClient = useQueryClient();
|
||||
const previousStatusRef = useRef(status);
|
||||
|
||||
useEffect(() => {
|
||||
const previousStatus = previousStatusRef.current;
|
||||
previousStatusRef.current = status;
|
||||
|
||||
const wasActive =
|
||||
previousStatus === "streaming" || previousStatus === "submitted";
|
||||
const isNowReady = status === "ready";
|
||||
|
||||
if (!wasActive || !isNowReady || !sessionId || isReconnecting) {
|
||||
return;
|
||||
}
|
||||
|
||||
const params = getSessionListParams();
|
||||
const queryKey = getGetV2ListSessionsQueryKey(params);
|
||||
let attempts = 0;
|
||||
let timeoutId: ReturnType<typeof setTimeout> | undefined;
|
||||
let isCancelled = false;
|
||||
|
||||
const poll = () => {
|
||||
if (isCancelled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const data =
|
||||
queryClient.getQueryData<getV2ListSessionsResponse>(queryKey);
|
||||
const hasTitle =
|
||||
data?.status === 200 &&
|
||||
data.data.sessions.some(
|
||||
(session) => session.id === sessionId && session.title,
|
||||
);
|
||||
|
||||
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
|
||||
return;
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
queryClient.invalidateQueries({ queryKey });
|
||||
timeoutId = setTimeout(poll, TITLE_POLL_INTERVAL_MS);
|
||||
};
|
||||
|
||||
queryClient.invalidateQueries({ queryKey });
|
||||
timeoutId = setTimeout(poll, TITLE_POLL_INTERVAL_MS);
|
||||
|
||||
return () => {
|
||||
isCancelled = true;
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
};
|
||||
}, [isReconnecting, queryClient, sessionId, status]);
|
||||
}
|
||||
@@ -1030,6 +1030,16 @@
|
||||
"default": 0,
|
||||
"title": "Offset"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "with_auto",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"title": "With Auto"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -1079,6 +1089,47 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/callback-token/consume": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Consume Callback Token Route",
|
||||
"operationId": "postV2ConsumeCallbackTokenRoute",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ConsumeCallbackTokenRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ConsumeCallbackTokenResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}": {
|
||||
"delete": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -6670,6 +6721,145 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/copilot/send-emails": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Send Pending Copilot Emails",
|
||||
"operationId": "postV2SendPendingCopilotEmails",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SendCopilotEmailsRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SendCopilotEmailsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/users/admin/copilot/trigger": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Trigger Copilot Session",
|
||||
"operationId": "postV2TriggerCopilotSession",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/TriggerCopilotSessionRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/TriggerCopilotSessionResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/users/admin/copilot/users": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Search Copilot Users",
|
||||
"operationId": "getV2SearchCopilotUsers",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "search",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Search by email, name, or user ID",
|
||||
"default": "",
|
||||
"title": "Search"
|
||||
},
|
||||
"description": "Search by email, name, or user ID"
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 50,
|
||||
"minimum": 1,
|
||||
"default": 20,
|
||||
"title": "Limit"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/AdminCopilotUsersResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
@@ -7270,6 +7460,42 @@
|
||||
"required": ["new_balance", "transaction_key"],
|
||||
"title": "AddUserCreditsResponse"
|
||||
},
|
||||
"AdminCopilotUserSummary": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"email": { "type": "string", "title": "Email" },
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
},
|
||||
"timezone": { "type": "string", "title": "Timezone" },
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Created At"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Updated At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "email", "timezone", "created_at", "updated_at"],
|
||||
"title": "AdminCopilotUserSummary"
|
||||
},
|
||||
"AdminCopilotUsersResponse": {
|
||||
"properties": {
|
||||
"users": {
|
||||
"items": { "$ref": "#/components/schemas/AdminCopilotUserSummary" },
|
||||
"type": "array",
|
||||
"title": "Users"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["users"],
|
||||
"title": "AdminCopilotUsersResponse"
|
||||
},
|
||||
"AgentDetails": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -7283,14 +7509,12 @@
|
||||
"inputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Inputs",
|
||||
"default": {}
|
||||
"title": "Inputs"
|
||||
},
|
||||
"credentials": {
|
||||
"items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
|
||||
"type": "array",
|
||||
"title": "Credentials",
|
||||
"default": []
|
||||
"title": "Credentials"
|
||||
},
|
||||
"execution_options": {
|
||||
"$ref": "#/components/schemas/ExecutionOptions"
|
||||
@@ -7867,20 +8091,17 @@
|
||||
"inputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Inputs",
|
||||
"default": {}
|
||||
"title": "Inputs"
|
||||
},
|
||||
"outputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Outputs",
|
||||
"default": {}
|
||||
"title": "Outputs"
|
||||
},
|
||||
"credentials": {
|
||||
"items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
|
||||
"type": "array",
|
||||
"title": "Credentials",
|
||||
"default": []
|
||||
"title": "Credentials"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -8419,6 +8640,16 @@
|
||||
"required": ["query", "conversation_history", "message_id"],
|
||||
"title": "ChatRequest"
|
||||
},
|
||||
"ChatSessionStartType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"MANUAL",
|
||||
"AUTOPILOT_NIGHTLY",
|
||||
"AUTOPILOT_CALLBACK",
|
||||
"AUTOPILOT_INVITE_CTA"
|
||||
],
|
||||
"title": "ChatSessionStartType"
|
||||
},
|
||||
"ClarificationNeededResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
@@ -8455,6 +8686,20 @@
|
||||
"title": "ClarifyingQuestion",
|
||||
"description": "A question that needs user clarification."
|
||||
},
|
||||
"ConsumeCallbackTokenRequest": {
|
||||
"properties": { "token": { "type": "string", "title": "Token" } },
|
||||
"type": "object",
|
||||
"required": ["token"],
|
||||
"title": "ConsumeCallbackTokenRequest"
|
||||
},
|
||||
"ConsumeCallbackTokenResponse": {
|
||||
"properties": {
|
||||
"session_id": { "type": "string", "title": "Session Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["session_id"],
|
||||
"title": "ConsumeCallbackTokenResponse"
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -10878,8 +11123,7 @@
|
||||
"suggestions": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Suggestions",
|
||||
"default": []
|
||||
"title": "Suggestions"
|
||||
},
|
||||
"name": { "type": "string", "title": "Name", "default": "no_results" }
|
||||
},
|
||||
@@ -11915,6 +12159,7 @@
|
||||
"error",
|
||||
"no_results",
|
||||
"need_login",
|
||||
"completion_report_saved",
|
||||
"agents_found",
|
||||
"agent_details",
|
||||
"setup_requirements",
|
||||
@@ -12171,6 +12416,37 @@
|
||||
"required": ["items", "search_id", "total_items", "pagination"],
|
||||
"title": "SearchResponse"
|
||||
},
|
||||
"SendCopilotEmailsRequest": {
|
||||
"properties": { "user_id": { "type": "string", "title": "User Id" } },
|
||||
"type": "object",
|
||||
"required": ["user_id"],
|
||||
"title": "SendCopilotEmailsRequest"
|
||||
},
|
||||
"SendCopilotEmailsResponse": {
|
||||
"properties": {
|
||||
"candidate_count": { "type": "integer", "title": "Candidate Count" },
|
||||
"processed_count": { "type": "integer", "title": "Processed Count" },
|
||||
"sent_count": { "type": "integer", "title": "Sent Count" },
|
||||
"skipped_count": { "type": "integer", "title": "Skipped Count" },
|
||||
"repair_queued_count": {
|
||||
"type": "integer",
|
||||
"title": "Repair Queued Count"
|
||||
},
|
||||
"running_count": { "type": "integer", "title": "Running Count" },
|
||||
"failed_count": { "type": "integer", "title": "Failed Count" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"candidate_count",
|
||||
"processed_count",
|
||||
"sent_count",
|
||||
"skipped_count",
|
||||
"repair_queued_count",
|
||||
"running_count",
|
||||
"failed_count"
|
||||
],
|
||||
"title": "SendCopilotEmailsResponse"
|
||||
},
|
||||
"SessionDetailResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -12180,6 +12456,11 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
},
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" },
|
||||
"execution_tag": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Execution Tag"
|
||||
},
|
||||
"messages": {
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
@@ -12193,7 +12474,14 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "created_at", "updated_at", "user_id", "messages"],
|
||||
"required": [
|
||||
"id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"user_id",
|
||||
"start_type",
|
||||
"messages"
|
||||
],
|
||||
"title": "SessionDetailResponse",
|
||||
"description": "Response model providing complete details for a chat session, including messages."
|
||||
},
|
||||
@@ -12206,10 +12494,21 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" },
|
||||
"execution_tag": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Execution Tag"
|
||||
},
|
||||
"is_processing": { "type": "boolean", "title": "Is Processing" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "created_at", "updated_at", "is_processing"],
|
||||
"required": [
|
||||
"id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"start_type",
|
||||
"is_processing"
|
||||
],
|
||||
"title": "SessionSummaryResponse",
|
||||
"description": "Response model for a session summary (without messages)."
|
||||
},
|
||||
@@ -13840,6 +14139,24 @@
|
||||
"required": ["transactions", "next_transaction_time"],
|
||||
"title": "TransactionHistory"
|
||||
},
|
||||
"TriggerCopilotSessionRequest": {
|
||||
"properties": {
|
||||
"user_id": { "type": "string", "title": "User Id" },
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["user_id", "start_type"],
|
||||
"title": "TriggerCopilotSessionRequest"
|
||||
},
|
||||
"TriggerCopilotSessionResponse": {
|
||||
"properties": {
|
||||
"session_id": { "type": "string", "title": "Session Id" },
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["session_id", "start_type"],
|
||||
"title": "TriggerCopilotSessionResponse"
|
||||
},
|
||||
"TriggeredPresetSetupRequest": {
|
||||
"properties": {
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
@@ -14771,8 +15088,7 @@
|
||||
"missing_credentials": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Missing Credentials",
|
||||
"default": {}
|
||||
"title": "Missing Credentials"
|
||||
},
|
||||
"ready_to_run": {
|
||||
"type": "boolean",
|
||||
|
||||
@@ -125,9 +125,9 @@ export function useCredentialsInput({
|
||||
if (hasAttemptedAutoSelect.current) return;
|
||||
hasAttemptedAutoSelect.current = true;
|
||||
|
||||
// Auto-select only when there is exactly one saved credential.
|
||||
// With multiple options the user must choose — regardless of optional/required.
|
||||
if (savedCreds.length > 1) return;
|
||||
// Auto-select if exactly one credential matches.
|
||||
// For optional fields with multiple options, let the user choose.
|
||||
if (isOptional && savedCreds.length > 1) return;
|
||||
|
||||
const cred = savedCreds[0];
|
||||
onSelectCredential({
|
||||
|
||||
Reference in New Issue
Block a user