mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
15 Commits
fix/copilo
...
swiftyos/n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b6903b70d | ||
|
|
e0e7b129ed | ||
|
|
2f6f02971e | ||
|
|
dfda58306b | ||
|
|
f13e4f60c9 | ||
|
|
2246799694 | ||
|
|
2456882e47 | ||
|
|
5fe35fd156 | ||
|
|
43d71107ef | ||
|
|
050dcd02b6 | ||
|
|
72856b0c11 | ||
|
|
5f574a5974 | ||
|
|
c773faca96 | ||
|
|
97d83aaa75 | ||
|
|
182927a1d4 |
@@ -6,11 +6,13 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
from backend.data.model import User
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
@@ -90,3 +92,51 @@ class BulkInvitedUsersResponse(BaseModel):
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AdminCopilotUserSummary(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: Optional[str] = None
|
||||
timezone: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_user(cls, user: "User") -> "AdminCopilotUserSummary":
|
||||
return cls(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
timezone=user.timezone,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class AdminCopilotUsersResponse(BaseModel):
|
||||
users: list[AdminCopilotUserSummary]
|
||||
|
||||
|
||||
class TriggerCopilotSessionRequest(BaseModel):
|
||||
user_id: str
|
||||
start_type: ChatSessionStartType
|
||||
|
||||
|
||||
class TriggerCopilotSessionResponse(BaseModel):
|
||||
session_id: str
|
||||
start_type: ChatSessionStartType
|
||||
|
||||
|
||||
class SendCopilotEmailsRequest(BaseModel):
|
||||
user_id: str
|
||||
|
||||
|
||||
class SendCopilotEmailsResponse(BaseModel):
|
||||
candidate_count: int
|
||||
processed_count: int
|
||||
sent_count: int
|
||||
skipped_count: int
|
||||
repair_queued_count: int
|
||||
running_count: int
|
||||
failed_count: int
|
||||
|
||||
@@ -2,8 +2,12 @@ import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
from fastapi import APIRouter, File, HTTPException, Query, Security, UploadFile
|
||||
|
||||
from backend.copilot.autopilot import (
|
||||
send_pending_copilot_emails_for_user,
|
||||
trigger_autopilot_session_for_user,
|
||||
)
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
@@ -12,13 +16,20 @@ from backend.data.invited_user import (
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.data.user import search_users
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
AdminCopilotUsersResponse,
|
||||
AdminCopilotUserSummary,
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
SendCopilotEmailsRequest,
|
||||
SendCopilotEmailsResponse,
|
||||
TriggerCopilotSessionRequest,
|
||||
TriggerCopilotSessionResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -135,3 +146,95 @@ async def retry_invited_user_tally_route(
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/copilot/users",
|
||||
response_model=AdminCopilotUsersResponse,
|
||||
summary="Search Copilot Users",
|
||||
operation_id="getV2SearchCopilotUsers",
|
||||
)
|
||||
async def search_copilot_users_route(
|
||||
search: str = Query("", description="Search by email, name, or user ID"),
|
||||
limit: int = Query(20, ge=1, le=50),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> AdminCopilotUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s searched Copilot users (query_length=%s, limit=%s)",
|
||||
admin_user_id,
|
||||
len(search.strip()),
|
||||
limit,
|
||||
)
|
||||
users = await search_users(search, limit=limit)
|
||||
return AdminCopilotUsersResponse(
|
||||
users=[AdminCopilotUserSummary.from_user(user) for user in users]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/copilot/trigger",
|
||||
response_model=TriggerCopilotSessionResponse,
|
||||
summary="Trigger Copilot Session",
|
||||
operation_id="postV2TriggerCopilotSession",
|
||||
)
|
||||
async def trigger_copilot_session_route(
|
||||
request: TriggerCopilotSessionRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> TriggerCopilotSessionResponse:
|
||||
logger.info(
|
||||
"Admin user %s manually triggered %s for user %s",
|
||||
admin_user_id,
|
||||
request.start_type,
|
||||
request.user_id,
|
||||
)
|
||||
try:
|
||||
session = await trigger_autopilot_session_for_user(
|
||||
request.user_id,
|
||||
start_type=request.start_type,
|
||||
)
|
||||
except LookupError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
logger.info(
|
||||
"Admin user %s created manual Copilot session %s for user %s",
|
||||
admin_user_id,
|
||||
session.session_id,
|
||||
request.user_id,
|
||||
)
|
||||
return TriggerCopilotSessionResponse(
|
||||
session_id=session.session_id,
|
||||
start_type=request.start_type,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/copilot/send-emails",
|
||||
response_model=SendCopilotEmailsResponse,
|
||||
summary="Send Pending Copilot Emails",
|
||||
operation_id="postV2SendPendingCopilotEmails",
|
||||
)
|
||||
async def send_pending_copilot_emails_route(
|
||||
request: SendCopilotEmailsRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> SendCopilotEmailsResponse:
|
||||
logger.info(
|
||||
"Admin user %s manually triggered pending Copilot emails for user %s",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
)
|
||||
result = await send_pending_copilot_emails_for_user(request.user_id)
|
||||
logger.info(
|
||||
"Admin user %s completed pending Copilot email sweep for user %s "
|
||||
"(candidates=%s, sent=%s, skipped=%s, repairs=%s, running=%s, failed=%s)",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
result.candidate_count,
|
||||
result.sent_count,
|
||||
result.skipped_count,
|
||||
result.repair_queued_count,
|
||||
result.running_count,
|
||||
result.failed_count,
|
||||
)
|
||||
return SendCopilotEmailsResponse(**result.model_dump())
|
||||
|
||||
@@ -8,11 +8,15 @@ import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.copilot.autopilot_email import PendingCopilotEmailSweepResult
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
from backend.data.model import User
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
@@ -72,6 +76,20 @@ def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
)
|
||||
|
||||
|
||||
def _sample_user() -> User:
|
||||
now = datetime.now(timezone.utc)
|
||||
return User(
|
||||
id="user-1",
|
||||
email="copilot@example.com",
|
||||
name="Copilot User",
|
||||
timezone="Europe/Madrid",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
stripe_customer_id=None,
|
||||
top_up_config=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
@@ -166,3 +184,107 @@ def test_retry_invited_user_tally(
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
|
||||
|
||||
def test_search_copilot_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.search_users",
|
||||
AsyncMock(return_value=[_sample_user()]),
|
||||
)
|
||||
|
||||
response = client.get("/admin/copilot/users", params={"search": "copilot"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["users"]) == 1
|
||||
assert data["users"][0]["email"] == "copilot@example.com"
|
||||
assert data["users"][0]["timezone"] == "Europe/Madrid"
|
||||
|
||||
|
||||
def test_trigger_copilot_session(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
)
|
||||
trigger = mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.trigger_autopilot_session_for_user",
|
||||
AsyncMock(return_value=session),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/copilot/trigger",
|
||||
json={
|
||||
"user_id": "user-1",
|
||||
"start_type": ChatSessionStartType.AUTOPILOT_CALLBACK.value,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["session_id"] == session.session_id
|
||||
assert response.json()["start_type"] == "AUTOPILOT_CALLBACK"
|
||||
assert trigger.await_args is not None
|
||||
assert trigger.await_args.args[0] == "user-1"
|
||||
assert (
|
||||
trigger.await_args.kwargs["start_type"]
|
||||
== ChatSessionStartType.AUTOPILOT_CALLBACK
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_copilot_session_returns_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.trigger_autopilot_session_for_user",
|
||||
AsyncMock(side_effect=LookupError("User not found with ID: missing-user")),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/copilot/trigger",
|
||||
json={
|
||||
"user_id": "missing-user",
|
||||
"start_type": ChatSessionStartType.AUTOPILOT_NIGHTLY.value,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "User not found with ID: missing-user"
|
||||
|
||||
|
||||
def test_send_pending_copilot_emails(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
send_emails = mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.send_pending_copilot_emails_for_user",
|
||||
AsyncMock(
|
||||
return_value=PendingCopilotEmailSweepResult(
|
||||
candidate_count=1,
|
||||
processed_count=1,
|
||||
sent_count=1,
|
||||
skipped_count=0,
|
||||
repair_queued_count=0,
|
||||
running_count=0,
|
||||
failed_count=0,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/copilot/send-emails",
|
||||
json={"user_id": "user-1"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"candidate_count": 1,
|
||||
"processed_count": 1,
|
||||
"sent_count": 1,
|
||||
"skipped_count": 0,
|
||||
"repair_queued_count": 0,
|
||||
"running_count": 0,
|
||||
"failed_count": 0,
|
||||
}
|
||||
send_emails.assert_awaited_once_with("user-1")
|
||||
|
||||
@@ -3,18 +3,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any, NoReturn
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.autopilot import (
|
||||
consume_callback_token,
|
||||
strip_internal_content,
|
||||
unwrap_internal_content,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
@@ -27,7 +32,13 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
)
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
@@ -53,6 +64,7 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
@@ -65,6 +77,187 @@ _UUID_RE = re.compile(
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
STREAM_QUEUE_GET_TIMEOUT_SECONDS = 10.0
|
||||
STREAMING_RESPONSE_HEADERS = {
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
}
|
||||
|
||||
|
||||
def _build_streaming_response(
|
||||
generator: AsyncGenerator[str, None],
|
||||
) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
media_type="text/event-stream",
|
||||
headers=STREAMING_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
async def _unsubscribe_stream_queue(
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] | None,
|
||||
) -> None:
|
||||
if subscriber_queue is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(session_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _stream_subscriber_queue(
|
||||
*,
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
log_meta: dict[str, Any],
|
||||
started_at: float,
|
||||
label: str,
|
||||
surface_errors: bool,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(),
|
||||
timeout=STREAM_QUEUE_GET_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
continue
|
||||
|
||||
chunk_count += 1
|
||||
if first_chunk_type is None:
|
||||
first_chunk_type = type(chunk).__name__
|
||||
elapsed = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] {label} first chunk at {elapsed:.1f}ms, type={first_chunk_type}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": first_chunk_type,
|
||||
"elapsed_ms": elapsed,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] {label} received StreamFinish in {total_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunk_count,
|
||||
"total_time_ms": total_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] {label} client disconnected after {chunk_count} chunks",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunk_count,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
elapsed = (time.perf_counter() - started_at) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] {label} error after {elapsed:.1f}ms: {exc}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(exc)}
|
||||
},
|
||||
)
|
||||
if surface_errors:
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
total_time = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] {label} finished in {total_time:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"chunks_yielded": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def _stream_chat_events(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
subscribe_from_id: str,
|
||||
turn_id: str,
|
||||
log_meta: dict[str, Any],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
started_at = time.perf_counter()
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
try:
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
async for chunk in _stream_subscriber_queue(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
log_meta=log_meta,
|
||||
started_at=started_at,
|
||||
label=f"stream_chat_post[{turn_id}]",
|
||||
surface_errors=True,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
await _unsubscribe_stream_queue(session_id, subscriber_queue)
|
||||
|
||||
|
||||
async def _resume_stream_events(
|
||||
*,
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
async for chunk in _stream_subscriber_queue(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
log_meta={"session_id": session_id},
|
||||
started_at=started_at,
|
||||
label=f"resume_stream[{session_id}]",
|
||||
surface_errors=False,
|
||||
):
|
||||
yield chunk
|
||||
finally:
|
||||
await _unsubscribe_stream_queue(session_id, subscriber_queue)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
@@ -118,6 +311,8 @@ class SessionDetailResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
start_type: ChatSessionStartType
|
||||
execution_tag: str | None = None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
@@ -129,6 +324,8 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
start_type: ChatSessionStartType
|
||||
execution_tag: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
@@ -160,6 +357,14 @@ class UpdateSessionTitleRequest(BaseModel):
|
||||
return stripped
|
||||
|
||||
|
||||
class ConsumeCallbackTokenRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
class ConsumeCallbackTokenResponse(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -171,6 +376,7 @@ async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
with_auto: bool = Query(default=False),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
@@ -186,7 +392,12 @@ async def list_sessions(
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
sessions, total_count = await get_user_sessions(
|
||||
user_id,
|
||||
limit,
|
||||
offset,
|
||||
with_auto=with_auto,
|
||||
)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
@@ -217,6 +428,8 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
@@ -368,12 +581,26 @@ async def get_session(
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
messages = []
|
||||
for message in session.messages:
|
||||
payload = message.model_dump()
|
||||
if message.role == "user":
|
||||
visible_content = strip_internal_content(message.content)
|
||||
if (
|
||||
visible_content is None
|
||||
and session.start_type != ChatSessionStartType.MANUAL
|
||||
):
|
||||
visible_content = unwrap_internal_content(message.content)
|
||||
if visible_content is None:
|
||||
continue
|
||||
payload["content"] = visible_content
|
||||
messages.append(payload)
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
session_id,
|
||||
user_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
@@ -394,11 +621,28 @@ async def get_session(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/callback-token/consume",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def consume_callback_token_route(
|
||||
request: ConsumeCallbackTokenRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConsumeCallbackTokenResponse:
|
||||
try:
|
||||
result = await consume_callback_token(request.token, user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
return ConsumeCallbackTokenResponse(session_id=result.session_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/cancel",
|
||||
status_code=200,
|
||||
@@ -472,9 +716,6 @@ async def stream_chat_post(
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
@@ -506,18 +747,14 @@ async def stream_chat_post(
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
files = await workspace_db().get_workspace_files_by_ids(
|
||||
workspace_id=workspace.id,
|
||||
file_ids=valid_ids,
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
f"- {wf.name} ({wf.mime_type}, {round(wf.size_bytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
@@ -587,141 +824,14 @@ async def stream_chat_post(
|
||||
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
return _build_streaming_response(
|
||||
_stream_chat_events(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
subscribe_from_id=subscribe_from_id,
|
||||
turn_id=turn_id,
|
||||
log_meta=log_meta,
|
||||
)
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
"[TIMING] Starting to read from subscriber_queue",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
first_chunk_yielded = True
|
||||
elapsed = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||
f"type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||
f"n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"total_time_ms": total_time * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
# Surface error to frontend so it doesn't appear stuck
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time * 1000,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -747,11 +857,7 @@ async def resume_session_stream(
|
||||
StreamingResponse (SSE) when an active stream exists,
|
||||
or 204 No Content when there is nothing to resume.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
|
||||
if not active_session:
|
||||
return Response(status_code=204)
|
||||
@@ -768,64 +874,11 @@ async def resume_session_stream(
|
||||
|
||||
if subscriber_queue is None:
|
||||
return Response(status_code=204)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
"Resume stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"n_chunks": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
return _build_streaming_response(
|
||||
_resume_stream_events(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -977,6 +1030,6 @@ ToolResponseUnion = (
|
||||
description="This endpoint is not meant to be called. It exists solely to "
|
||||
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||
)
|
||||
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||
async def _tool_response_schema() -> NoReturn:
|
||||
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
@@ -8,6 +11,9 @@ import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
@@ -115,6 +121,238 @@ def test_update_title_not_found(
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_list_sessions_defaults_to_manual_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
started_at = datetime.now(timezone.utc)
|
||||
mock_get_user_sessions = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_user_sessions",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(
|
||||
[
|
||||
SimpleNamespace(
|
||||
session_id="sess-1",
|
||||
started_at=started_at,
|
||||
updated_at=started_at,
|
||||
title="Nightly check-in",
|
||||
start_type=chat_routes.ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag="autopilot-nightly:2026-03-13",
|
||||
)
|
||||
],
|
||||
1,
|
||||
),
|
||||
)
|
||||
|
||||
pipe = MagicMock()
|
||||
pipe.hget = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=["running"])
|
||||
redis = MagicMock()
|
||||
redis.pipeline = MagicMock(return_value=pipe)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=redis,
|
||||
)
|
||||
|
||||
response = client.get("/sessions")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"sessions": [
|
||||
{
|
||||
"id": "sess-1",
|
||||
"created_at": started_at.isoformat(),
|
||||
"updated_at": started_at.isoformat(),
|
||||
"title": "Nightly check-in",
|
||||
"start_type": "AUTOPILOT_NIGHTLY",
|
||||
"execution_tag": "autopilot-nightly:2026-03-13",
|
||||
"is_processing": True,
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
}
|
||||
mock_get_user_sessions.assert_awaited_once_with(
|
||||
test_user_id,
|
||||
50,
|
||||
0,
|
||||
with_auto=False,
|
||||
)
|
||||
|
||||
|
||||
def test_list_sessions_can_include_auto_sessions(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_get_user_sessions = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_user_sessions",
|
||||
new_callable=AsyncMock,
|
||||
return_value=([], 0),
|
||||
)
|
||||
|
||||
response = client.get("/sessions?with_auto=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"sessions": [], "total": 0}
|
||||
mock_get_user_sessions.assert_awaited_once_with(
|
||||
test_user_id,
|
||||
50,
|
||||
0,
|
||||
with_auto=True,
|
||||
)
|
||||
|
||||
|
||||
def test_consume_callback_token_route_returns_session_id(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mock_consume = mocker.patch(
|
||||
"backend.api.features.chat.routes.consume_callback_token",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SimpleNamespace(session_id="sess-2"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/callback-token/consume",
|
||||
json={"token": "token-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"session_id": "sess-2"}
|
||||
mock_consume.assert_awaited_once_with("token-123", TEST_USER_ID)
|
||||
|
||||
|
||||
def test_consume_callback_token_route_returns_404_on_invalid_token(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.consume_callback_token",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("Callback token not found"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/callback-token/consume",
|
||||
json={"token": "token-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Callback token not found"}
|
||||
|
||||
|
||||
def test_get_session_hides_internal_only_messages_for_manual_sessions(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
session = ChatSession.new(
|
||||
TEST_USER_ID,
|
||||
start_type=ChatSessionStartType.MANUAL,
|
||||
)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="<internal>hidden</internal>"),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content="Visible<internal>hidden</internal> text",
|
||||
),
|
||||
ChatMessage(role="assistant", content="Public response"),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get(f"/sessions/{session.session_id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Visible text",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Public response",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_session_shows_cleaned_internal_kickoff_for_autopilot_sessions(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
session = ChatSession.new(
|
||||
TEST_USER_ID,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag="autopilot-nightly:2026-03-13",
|
||||
)
|
||||
session.messages = [
|
||||
ChatMessage(role="user", content="<internal>hidden</internal>"),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content="Visible<internal>hidden</internal> text",
|
||||
),
|
||||
ChatMessage(role="assistant", content="Public response"),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get(f"/sessions/{session.session_id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hidden",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Visible text",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Public response",
|
||||
"name": None,
|
||||
"tool_call_id": None,
|
||||
"refusal": None,
|
||||
"tool_calls": None,
|
||||
"function_call": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
@@ -142,7 +380,11 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
subscriber_queue = asyncio.Queue()
|
||||
subscriber_queue.put_nowait(StreamFinish())
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mock_registry.subscribe_to_session = mocker.AsyncMock(return_value=subscriber_queue)
|
||||
mock_registry.unsubscribe_from_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
@@ -165,11 +407,11 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
@@ -195,11 +437,11 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
@@ -217,9 +459,10 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
|
||||
workspace_id="ws-1",
|
||||
file_ids=[valid_id],
|
||||
)
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
@@ -233,11 +476,11 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
@@ -246,9 +489,10 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
|
||||
workspace_id="my-workspace-id",
|
||||
file_ids=[fid],
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
135
autogpt_platform/backend/backend/copilot/autopilot.py
Normal file
135
autogpt_platform/backend/backend/copilot/autopilot.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Autopilot public API — thin facade re-exporting from sub-modules.
|
||||
|
||||
Implementation is split by responsibility:
|
||||
- autopilot_prompts: constants, prompt templates, context builders
|
||||
- autopilot_dispatch: timezone helpers, session creation, dispatch/scheduling
|
||||
- autopilot_completion: completion report extraction, repair, handler
|
||||
- autopilot_email: email sending, link building, notification sweep
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.autopilot_completion import ( # noqa: F401
|
||||
CompletionReportToolCall,
|
||||
CompletionReportToolCallFunction,
|
||||
ToolOutputEnvelope,
|
||||
_build_completion_report_repair_message,
|
||||
_extract_completion_report_from_session,
|
||||
_get_pending_approval_metadata,
|
||||
_queue_completion_report_repair,
|
||||
handle_non_manual_session_completion,
|
||||
)
|
||||
from backend.copilot.autopilot_dispatch import ( # noqa: F401
|
||||
_bucket_end_for_now,
|
||||
_create_autopilot_session,
|
||||
_crosses_local_midnight,
|
||||
_enqueue_session_turn,
|
||||
_resolve_timezone_name,
|
||||
_session_exists_for_execution_tag,
|
||||
_try_create_callback_session,
|
||||
_try_create_invite_cta_session,
|
||||
_try_create_nightly_session,
|
||||
_user_has_recent_manual_message,
|
||||
_user_has_session_since,
|
||||
dispatch_nightly_copilot,
|
||||
get_callback_execution_tag,
|
||||
get_graph_exec_id_for_session,
|
||||
get_invite_cta_execution_tag,
|
||||
get_nightly_execution_tag,
|
||||
trigger_autopilot_session_for_user,
|
||||
)
|
||||
from backend.copilot.autopilot_email import ( # noqa: F401
|
||||
PendingCopilotEmailSweepResult,
|
||||
_build_session_link,
|
||||
_get_completion_email_template_name,
|
||||
_markdown_to_email_html,
|
||||
_send_completion_email,
|
||||
_send_nightly_copilot_emails,
|
||||
send_nightly_copilot_emails,
|
||||
send_pending_copilot_emails_for_user,
|
||||
)
|
||||
from backend.copilot.autopilot_prompts import ( # noqa: F401
|
||||
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_CALLBACK_TAG,
|
||||
AUTOPILOT_DISABLED_TOOLS,
|
||||
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_INVITE_CTA_TAG,
|
||||
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_NIGHTLY_TAG_PREFIX,
|
||||
DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT,
|
||||
DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT,
|
||||
DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT,
|
||||
INTERNAL_TAG_RE,
|
||||
MAX_COMPLETION_REPORT_REPAIRS,
|
||||
_build_autopilot_system_prompt,
|
||||
_format_start_type_label,
|
||||
_get_recent_manual_session_context,
|
||||
_get_recent_sent_email_context,
|
||||
_get_recent_session_summary_context,
|
||||
strip_internal_content,
|
||||
unwrap_internal_content,
|
||||
wrap_internal_message,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage, create_chat_session
|
||||
from backend.data.db_accessors import chat_db
|
||||
|
||||
# -- re-exports from sub-modules (preserves existing import paths) ---------- #
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackTokenConsumeResult(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
async def consume_callback_token(
|
||||
token_id: str,
|
||||
user_id: str,
|
||||
) -> CallbackTokenConsumeResult:
|
||||
"""Consume a callback token and return the resulting session.
|
||||
|
||||
Uses an atomic consume-and-update to prevent the TOCTOU race where two
|
||||
concurrent requests could each see the token as unconsumed and create
|
||||
duplicate sessions.
|
||||
"""
|
||||
db = chat_db()
|
||||
token = await db.get_chat_session_callback_token(token_id)
|
||||
if token is None or token.user_id != user_id:
|
||||
raise ValueError("Callback token not found")
|
||||
if token.expires_at <= datetime.now(UTC):
|
||||
raise ValueError("Callback token has expired")
|
||||
|
||||
if token.consumed_session_id:
|
||||
return CallbackTokenConsumeResult(session_id=token.consumed_session_id)
|
||||
|
||||
session = await create_chat_session(
|
||||
user_id,
|
||||
initial_messages=[
|
||||
ChatMessage(role="assistant", content=token.callback_session_message)
|
||||
],
|
||||
)
|
||||
|
||||
# Atomically mark consumed — if another request already consumed the token
|
||||
# concurrently, the DB will have a non-null consumed_session_id; we re-read
|
||||
# and return the winner's session instead.
|
||||
await db.mark_chat_session_callback_token_consumed(
|
||||
token_id,
|
||||
session.session_id,
|
||||
)
|
||||
|
||||
# Re-read to see if we won the race
|
||||
refreshed = await db.get_chat_session_callback_token(token_id)
|
||||
if (
|
||||
refreshed
|
||||
and refreshed.consumed_session_id
|
||||
and refreshed.consumed_session_id != session.session_id
|
||||
):
|
||||
return CallbackTokenConsumeResult(session_id=refreshed.consumed_session_id)
|
||||
|
||||
return CallbackTokenConsumeResult(session_id=session.session_id)
|
||||
198
autogpt_platform/backend/backend/copilot/autopilot_completion.py
Normal file
198
autogpt_platform/backend/backend/copilot/autopilot_completion.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.copilot.autopilot_dispatch import (
|
||||
_enqueue_session_turn,
|
||||
get_graph_exec_id_for_session,
|
||||
)
|
||||
from backend.copilot.autopilot_prompts import (
|
||||
MAX_COMPLETION_REPORT_REPAIRS,
|
||||
wrap_internal_message,
|
||||
)
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.session_types import CompletionReportInput, StoredCompletionReport
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --------------- models --------------- #
|
||||
|
||||
|
||||
class CompletionReportToolCallFunction(BaseModel):
|
||||
name: str | None = None
|
||||
arguments: str | None = None
|
||||
|
||||
|
||||
class CompletionReportToolCall(BaseModel):
|
||||
id: str
|
||||
function: CompletionReportToolCallFunction = Field(
|
||||
default_factory=CompletionReportToolCallFunction
|
||||
)
|
||||
|
||||
|
||||
class ToolOutputEnvelope(BaseModel):
|
||||
type: str | None = None
|
||||
|
||||
|
||||
# --------------- approval metadata --------------- #
|
||||
|
||||
|
||||
async def _get_pending_approval_metadata(
|
||||
session: ChatSession,
|
||||
) -> tuple[int, str | None]:
|
||||
graph_exec_id = get_graph_exec_id_for_session(session.session_id)
|
||||
pending_count = await review_db().count_pending_reviews_for_graph_exec(
|
||||
graph_exec_id,
|
||||
session.user_id,
|
||||
)
|
||||
return pending_count, graph_exec_id if pending_count > 0 else None
|
||||
|
||||
|
||||
# --------------- extraction --------------- #
|
||||
|
||||
|
||||
def _extract_completion_report_from_session(
|
||||
session: ChatSession,
|
||||
*,
|
||||
pending_approval_count: int,
|
||||
) -> CompletionReportInput | None:
|
||||
tool_outputs = {
|
||||
message.tool_call_id: message.content
|
||||
for message in session.messages
|
||||
if message.role == "tool" and message.tool_call_id
|
||||
}
|
||||
|
||||
latest_report: CompletionReportInput | None = None
|
||||
for message in session.messages:
|
||||
if message.role != "assistant" or not message.tool_calls:
|
||||
continue
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
try:
|
||||
parsed_tool_call = CompletionReportToolCall.model_validate(tool_call)
|
||||
except ValidationError:
|
||||
continue
|
||||
|
||||
if parsed_tool_call.function.name != "completion_report":
|
||||
continue
|
||||
|
||||
output = tool_outputs.get(parsed_tool_call.id)
|
||||
if not output:
|
||||
continue
|
||||
|
||||
try:
|
||||
output_payload = ToolOutputEnvelope.model_validate_json(output)
|
||||
except ValidationError:
|
||||
output_payload = None
|
||||
|
||||
if output_payload is not None and output_payload.type == "error":
|
||||
continue
|
||||
|
||||
try:
|
||||
raw_arguments = parsed_tool_call.function.arguments or "{}"
|
||||
report = CompletionReportInput.model_validate_json(raw_arguments)
|
||||
except ValidationError:
|
||||
continue
|
||||
|
||||
if pending_approval_count > 0 and not report.approval_summary:
|
||||
continue
|
||||
|
||||
latest_report = report
|
||||
|
||||
return latest_report
|
||||
|
||||
|
||||
# --------------- repair --------------- #
|
||||
|
||||
|
||||
def _build_completion_report_repair_message(
|
||||
*,
|
||||
attempt: int,
|
||||
pending_approval_count: int,
|
||||
) -> str:
|
||||
approval_instruction = ""
|
||||
if pending_approval_count > 0:
|
||||
approval_instruction = (
|
||||
f" There are currently {pending_approval_count} pending approval item(s). "
|
||||
"If they still exist, include approval_summary."
|
||||
)
|
||||
|
||||
return wrap_internal_message(
|
||||
"The session completed without a valid completion_report tool call. "
|
||||
f"This is repair attempt {attempt}. Call completion_report now and do not do any additional user-facing work."
|
||||
+ approval_instruction
|
||||
)
|
||||
|
||||
|
||||
async def _queue_completion_report_repair(
|
||||
session: ChatSession,
|
||||
*,
|
||||
pending_approval_count: int,
|
||||
) -> None:
|
||||
attempt = session.completion_report_repair_count + 1
|
||||
repair_message = _build_completion_report_repair_message(
|
||||
attempt=attempt,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
session.messages.append(ChatMessage(role="user", content=repair_message))
|
||||
session.completion_report_repair_count = attempt
|
||||
session.completion_report_repair_queued_at = datetime.now(UTC)
|
||||
session.completed_at = None
|
||||
session.completion_report = None
|
||||
await upsert_chat_session(session)
|
||||
await _enqueue_session_turn(
|
||||
session,
|
||||
message=repair_message,
|
||||
tool_name="completion_report_repair",
|
||||
)
|
||||
|
||||
|
||||
# --------------- handler --------------- #
|
||||
|
||||
|
||||
async def handle_non_manual_session_completion(session_id: str) -> None:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None or session.is_manual:
|
||||
return
|
||||
|
||||
pending_approval_count, graph_exec_id = await _get_pending_approval_metadata(
|
||||
session
|
||||
)
|
||||
report = _extract_completion_report_from_session(
|
||||
session,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
|
||||
if report is not None:
|
||||
session.completion_report = StoredCompletionReport(
|
||||
**report.model_dump(),
|
||||
has_pending_approvals=pending_approval_count > 0,
|
||||
pending_approval_count=pending_approval_count,
|
||||
pending_approval_graph_exec_id=graph_exec_id,
|
||||
saved_at=datetime.now(UTC),
|
||||
)
|
||||
session.completion_report_repair_queued_at = None
|
||||
session.completed_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
return
|
||||
|
||||
if session.completion_report_repair_count >= MAX_COMPLETION_REPORT_REPAIRS:
|
||||
session.completion_report_repair_queued_at = None
|
||||
session.completed_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
return
|
||||
|
||||
await _queue_completion_report_repair(
|
||||
session,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
386
autogpt_platform/backend/backend/copilot/autopilot_dispatch.py
Normal file
386
autogpt_platform/backend/backend/copilot/autopilot_dispatch.py
Normal file
@@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.autopilot_prompts import (
|
||||
AUTOPILOT_CALLBACK_TAG,
|
||||
AUTOPILOT_DISABLED_TOOLS,
|
||||
AUTOPILOT_INVITE_CTA_TAG,
|
||||
AUTOPILOT_NIGHTLY_TAG_PREFIX,
|
||||
_build_autopilot_system_prompt,
|
||||
_render_initial_message,
|
||||
)
|
||||
from backend.copilot.constants import COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.executor.utils import enqueue_copilot_turn
|
||||
from backend.copilot.model import ChatMessage, ChatSession, create_chat_session
|
||||
from backend.copilot.session_types import ChatSessionConfig, ChatSessionStartType
|
||||
from backend.data.db_accessors import chat_db, invited_user_db, user_db
|
||||
from backend.data.model import User
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import InvitedUserRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
DISPATCH_BATCH_SIZE = 500
|
||||
|
||||
|
||||
# --------------- tag helpers --------------- #
|
||||
|
||||
|
||||
def get_graph_exec_id_for_session(session_id: str) -> str:
|
||||
return f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
|
||||
|
||||
def get_nightly_execution_tag(target_local_date: date) -> str:
|
||||
return f"{AUTOPILOT_NIGHTLY_TAG_PREFIX}{target_local_date.isoformat()}"
|
||||
|
||||
|
||||
def get_callback_execution_tag() -> str:
|
||||
return AUTOPILOT_CALLBACK_TAG
|
||||
|
||||
|
||||
def get_invite_cta_execution_tag() -> str:
|
||||
return AUTOPILOT_INVITE_CTA_TAG
|
||||
|
||||
|
||||
def _get_manual_trigger_execution_tag(start_type: ChatSessionStartType) -> str:
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%fZ")
|
||||
return f"admin-autopilot:{start_type.value}:{timestamp}:{uuid4()}"
|
||||
|
||||
|
||||
# --------------- timezone helpers --------------- #
|
||||
|
||||
|
||||
def _bucket_end_for_now(now_utc: datetime) -> datetime:
|
||||
minute = 30 if now_utc.minute >= 30 else 0
|
||||
return now_utc.replace(minute=minute, second=0, microsecond=0)
|
||||
|
||||
|
||||
def _resolve_timezone_name(raw_timezone: str | None) -> str:
|
||||
return get_user_timezone_or_utc(raw_timezone)
|
||||
|
||||
|
||||
def _crosses_local_midnight(
|
||||
bucket_start_utc: datetime,
|
||||
bucket_end_utc: datetime,
|
||||
timezone_name: str,
|
||||
) -> date | None:
|
||||
"""Return the new local date if *bucket_end_utc* falls on a different local
|
||||
date than *bucket_start_utc*, taking DST transitions into account.
|
||||
|
||||
During a DST spring-forward the wall clock jumps forward (e.g. 01:59 →
|
||||
03:00). We use ``fold=0`` on the end instant so that ambiguous/missing
|
||||
times are resolved consistently and a single 30-min UTC bucket can never
|
||||
produce midnight on *two* consecutive calls.
|
||||
"""
|
||||
tz = ZoneInfo(timezone_name)
|
||||
start_local = bucket_start_utc.astimezone(tz)
|
||||
end_local = bucket_end_utc.astimezone(tz)
|
||||
# Resolve ambiguous wall-clock times consistently (spring-forward / fall-back)
|
||||
end_local = end_local.replace(fold=0)
|
||||
if start_local.date() == end_local.date():
|
||||
return None
|
||||
return end_local.date()
|
||||
|
||||
|
||||
# --------------- thin DB wrappers --------------- #
|
||||
|
||||
|
||||
async def _user_has_recent_manual_message(user_id: str, since: datetime) -> bool:
|
||||
return await chat_db().has_recent_manual_message(user_id, since)
|
||||
|
||||
|
||||
async def _user_has_session_since(user_id: str, since: datetime) -> bool:
|
||||
return await chat_db().has_session_since(user_id, since)
|
||||
|
||||
|
||||
async def _session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
|
||||
return await chat_db().session_exists_for_execution_tag(user_id, execution_tag)
|
||||
|
||||
|
||||
# --------------- session creation --------------- #
|
||||
|
||||
|
||||
async def _enqueue_session_turn(
|
||||
session: ChatSession,
|
||||
*,
|
||||
message: str,
|
||||
tool_name: str,
|
||||
) -> None:
|
||||
turn_id = str(uuid4())
|
||||
await stream_registry.create_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
tool_call_id=tool_name,
|
||||
tool_name=tool_name,
|
||||
turn_id=turn_id,
|
||||
blocking=False,
|
||||
)
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
message=message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
|
||||
async def _create_autopilot_session(
|
||||
user: User,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
execution_tag: str,
|
||||
timezone_name: str,
|
||||
target_local_date: date | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> ChatSession | None:
|
||||
if await _session_exists_for_execution_tag(user.id, execution_tag):
|
||||
return None
|
||||
|
||||
system_prompt = await _build_autopilot_system_prompt(
|
||||
user,
|
||||
start_type=start_type,
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
initial_message = _render_initial_message(
|
||||
start_type,
|
||||
user_name=user.name,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
session_config = ChatSessionConfig(
|
||||
system_prompt_override=system_prompt,
|
||||
initial_user_message=initial_message,
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=AUTOPILOT_DISABLED_TOOLS,
|
||||
)
|
||||
|
||||
session = await create_chat_session(
|
||||
user.id,
|
||||
start_type=start_type,
|
||||
execution_tag=execution_tag,
|
||||
session_config=session_config,
|
||||
initial_messages=[ChatMessage(role="user", content=initial_message)],
|
||||
)
|
||||
await _enqueue_session_turn(
|
||||
session,
|
||||
message=initial_message,
|
||||
tool_name="autopilot_dispatch",
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
# --------------- cohort helpers --------------- #
|
||||
|
||||
|
||||
async def _try_create_invite_cta_session(
|
||||
user: User,
|
||||
*,
|
||||
invited_user: InvitedUserRecord | None,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
invite_cta_start: date,
|
||||
invite_cta_delay: timedelta,
|
||||
) -> bool:
|
||||
if invited_user is None:
|
||||
return False
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
return False
|
||||
if invited_user.created_at.date() < invite_cta_start:
|
||||
return False
|
||||
if invited_user.created_at > now_utc - invite_cta_delay:
|
||||
return False
|
||||
if await _session_exists_for_execution_tag(user.id, get_invite_cta_execution_tag()):
|
||||
return False
|
||||
|
||||
created = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_INVITE_CTA,
|
||||
execution_tag=get_invite_cta_execution_tag(),
|
||||
timezone_name=timezone_name,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
return created is not None
|
||||
|
||||
|
||||
async def _try_create_nightly_session(
|
||||
user: User,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
target_local_date: date,
|
||||
) -> bool:
|
||||
if not await _user_has_recent_manual_message(
|
||||
user.id,
|
||||
now_utc - timedelta(hours=24),
|
||||
):
|
||||
return False
|
||||
|
||||
created = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag=get_nightly_execution_tag(target_local_date),
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
)
|
||||
return created is not None
|
||||
|
||||
|
||||
async def _try_create_callback_session(
|
||||
user: User,
|
||||
*,
|
||||
callback_start: datetime,
|
||||
timezone_name: str,
|
||||
) -> bool:
|
||||
if not await _user_has_session_since(user.id, callback_start):
|
||||
return False
|
||||
if await _session_exists_for_execution_tag(user.id, get_callback_execution_tag()):
|
||||
return False
|
||||
|
||||
created = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
execution_tag=get_callback_execution_tag(),
|
||||
timezone_name=timezone_name,
|
||||
)
|
||||
return created is not None
|
||||
|
||||
|
||||
# --------------- dispatch --------------- #
|
||||
|
||||
|
||||
async def _dispatch_nightly_copilot() -> int:
|
||||
now_utc = datetime.now(UTC)
|
||||
bucket_end = _bucket_end_for_now(now_utc)
|
||||
bucket_start = bucket_end - timedelta(minutes=30)
|
||||
callback_start = datetime.combine(
|
||||
settings.config.nightly_copilot_callback_start_date,
|
||||
time.min,
|
||||
tzinfo=UTC,
|
||||
)
|
||||
invite_cta_start = settings.config.nightly_copilot_invite_cta_start_date
|
||||
invite_cta_delay = timedelta(
|
||||
hours=settings.config.nightly_copilot_invite_cta_delay_hours
|
||||
)
|
||||
|
||||
# Paginate user list to avoid loading the entire table into memory.
|
||||
created_count = 0
|
||||
cursor: str | None = None
|
||||
while True:
|
||||
batch = await user_db().list_users(
|
||||
limit=DISPATCH_BATCH_SIZE,
|
||||
cursor=cursor,
|
||||
)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
user_ids = [user.id for user in batch]
|
||||
invites = await invited_user_db().list_invited_users_for_auth_users(user_ids)
|
||||
invites_by_user_id = {
|
||||
invite.auth_user_id: invite for invite in invites if invite.auth_user_id
|
||||
}
|
||||
|
||||
for user in batch:
|
||||
if not await is_feature_enabled(
|
||||
Flag.NIGHTLY_COPILOT, user.id, default=False
|
||||
):
|
||||
continue
|
||||
|
||||
timezone_name = _resolve_timezone_name(user.timezone)
|
||||
target_local_date = _crosses_local_midnight(
|
||||
bucket_start,
|
||||
bucket_end,
|
||||
timezone_name,
|
||||
)
|
||||
if target_local_date is None:
|
||||
continue
|
||||
|
||||
invited_user = invites_by_user_id.get(user.id)
|
||||
if await _try_create_invite_cta_session(
|
||||
user,
|
||||
invited_user=invited_user,
|
||||
now_utc=now_utc,
|
||||
timezone_name=timezone_name,
|
||||
invite_cta_start=invite_cta_start,
|
||||
invite_cta_delay=invite_cta_delay,
|
||||
):
|
||||
created_count += 1
|
||||
continue
|
||||
|
||||
if await _try_create_nightly_session(
|
||||
user,
|
||||
now_utc=now_utc,
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
):
|
||||
created_count += 1
|
||||
continue
|
||||
|
||||
if await _try_create_callback_session(
|
||||
user,
|
||||
callback_start=callback_start,
|
||||
timezone_name=timezone_name,
|
||||
):
|
||||
created_count += 1
|
||||
|
||||
cursor = batch[-1].id if len(batch) == DISPATCH_BATCH_SIZE else None
|
||||
if cursor is None:
|
||||
break
|
||||
|
||||
return created_count
|
||||
|
||||
|
||||
async def dispatch_nightly_copilot() -> int:
|
||||
return await _dispatch_nightly_copilot()
|
||||
|
||||
|
||||
async def trigger_autopilot_session_for_user(
|
||||
user_id: str,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
) -> ChatSession:
|
||||
allowed_start_types = {
|
||||
ChatSessionStartType.AUTOPILOT_INVITE_CTA,
|
||||
ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
}
|
||||
if start_type not in allowed_start_types:
|
||||
raise ValueError(f"Unsupported autopilot start type: {start_type}")
|
||||
|
||||
try:
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
except ValueError as exc:
|
||||
raise LookupError(str(exc)) from exc
|
||||
|
||||
invites = await invited_user_db().list_invited_users_for_auth_users([user_id])
|
||||
invited_user = invites[0] if invites else None
|
||||
timezone_name = _resolve_timezone_name(user.timezone)
|
||||
target_local_date = None
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
target_local_date = datetime.now(UTC).astimezone(ZoneInfo(timezone_name)).date()
|
||||
|
||||
session = await _create_autopilot_session(
|
||||
user,
|
||||
start_type=start_type,
|
||||
execution_tag=_get_manual_trigger_execution_tag(start_type),
|
||||
timezone_name=timezone_name,
|
||||
target_local_date=target_local_date,
|
||||
invited_user=invited_user,
|
||||
)
|
||||
if session is None:
|
||||
raise ValueError("Failed to create autopilot session")
|
||||
|
||||
return session
|
||||
297
autogpt_platform/backend/backend/copilot/autopilot_email.py
Normal file
297
autogpt_platform/backend/backend/copilot/autopilot_email.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.autopilot_completion import (
|
||||
_get_pending_approval_metadata,
|
||||
_queue_completion_report_repair,
|
||||
)
|
||||
from backend.copilot.autopilot_prompts import (
|
||||
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE,
|
||||
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE,
|
||||
MAX_COMPLETION_REPORT_REPAIRS,
|
||||
)
|
||||
from backend.copilot.model import (
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.service import _generate_session_title
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.db_accessors import chat_db, user_db
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.url import get_frontend_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
PENDING_NOTIFICATION_SWEEP_LIMIT = 200
|
||||
|
||||
_md = MarkdownIt()
|
||||
|
||||
_EMAIL_INLINE_STYLES: list[tuple[str, str]] = [
|
||||
(
|
||||
"<p>",
|
||||
'<p style="font-size: 15px; line-height: 170%;'
|
||||
" margin-top: 0; margin-bottom: 16px;"
|
||||
' color: #1F1F20;">',
|
||||
),
|
||||
(
|
||||
"<li>",
|
||||
'<li style="font-size: 15px; line-height: 170%;'
|
||||
" margin-top: 0; margin-bottom: 8px;"
|
||||
' color: #1F1F20;">',
|
||||
),
|
||||
(
|
||||
"<ul>",
|
||||
'<ul style="padding: 0 0 0 24px;' ' margin-top: 0; margin-bottom: 16px;">',
|
||||
),
|
||||
(
|
||||
"<ol>",
|
||||
'<ol style="padding: 0 0 0 24px;' ' margin-top: 0; margin-bottom: 16px;">',
|
||||
),
|
||||
(
|
||||
"<a ",
|
||||
'<a style="color: #7733F5;'
|
||||
" text-decoration: underline;"
|
||||
' font-weight: 500;" ',
|
||||
),
|
||||
(
|
||||
"<h2>",
|
||||
'<h2 style="font-size: 20px; font-weight: 600;'
|
||||
' margin-top: 0; margin-bottom: 12px; color: #1F1F20;">',
|
||||
),
|
||||
(
|
||||
"<h3>",
|
||||
'<h3 style="font-size: 18px; font-weight: 600;'
|
||||
' margin-top: 0; margin-bottom: 12px; color: #1F1F20;">',
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _markdown_to_email_html(text: str | None) -> str:
|
||||
"""Convert markdown text to email-safe HTML with inline styles."""
|
||||
if not text or not text.strip():
|
||||
return ""
|
||||
html = _md.render(text.strip())
|
||||
for tag, styled_tag in _EMAIL_INLINE_STYLES:
|
||||
html = html.replace(tag, styled_tag)
|
||||
return html.strip()
|
||||
|
||||
|
||||
# --------------- link builders --------------- #
|
||||
|
||||
|
||||
def _build_session_link(session_id: str, *, show_autopilot: bool) -> str:
|
||||
base_url = get_frontend_base_url()
|
||||
suffix = "&showAutopilot=1" if show_autopilot else ""
|
||||
return f"{base_url}/copilot?sessionId={session_id}{suffix}"
|
||||
|
||||
|
||||
def _get_completion_email_template_name(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return AUTOPILOT_CALLBACK_EMAIL_TEMPLATE
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE
|
||||
raise ValueError(f"Unsupported start type for completion email: {start_type}")
|
||||
|
||||
|
||||
class PendingCopilotEmailSweepResult(BaseModel):
|
||||
candidate_count: int = 0
|
||||
processed_count: int = 0
|
||||
sent_count: int = 0
|
||||
skipped_count: int = 0
|
||||
repair_queued_count: int = 0
|
||||
running_count: int = 0
|
||||
failed_count: int = 0
|
||||
|
||||
|
||||
async def _ensure_session_title_for_completed_session(session: ChatSession) -> None:
|
||||
if session.title or not session.user_id:
|
||||
return
|
||||
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
return
|
||||
|
||||
title = report.email_title.strip() if report.email_title else ""
|
||||
if not title:
|
||||
title_seed = report.email_body or report.thoughts
|
||||
if title_seed:
|
||||
generated_title = await _generate_session_title(
|
||||
title_seed,
|
||||
user_id=session.user_id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
title = generated_title.strip() if generated_title else ""
|
||||
|
||||
if not title:
|
||||
return
|
||||
|
||||
updated = await update_session_title(
|
||||
session.session_id,
|
||||
session.user_id,
|
||||
title,
|
||||
only_if_empty=True,
|
||||
)
|
||||
if updated:
|
||||
session.title = title
|
||||
|
||||
|
||||
# --------------- send email --------------- #
|
||||
|
||||
|
||||
async def _send_completion_email(session: ChatSession) -> None:
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
raise ValueError("Missing completion report")
|
||||
try:
|
||||
user = await user_db().get_user_by_id(session.user_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"User {session.user_id} not found") from exc
|
||||
|
||||
if not user.email:
|
||||
raise ValueError(f"User {session.user_id} not found")
|
||||
|
||||
approval_cta = report.has_pending_approvals
|
||||
template_name = _get_completion_email_template_name(session.start_type)
|
||||
if approval_cta:
|
||||
cta_url = _build_session_link(session.session_id, show_autopilot=True)
|
||||
cta_label = "Review in Copilot"
|
||||
else:
|
||||
cta_url = _build_session_link(session.session_id, show_autopilot=True)
|
||||
cta_label = (
|
||||
"Try Copilot"
|
||||
if session.start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA
|
||||
else "Open Copilot"
|
||||
)
|
||||
|
||||
# EmailSender.send_template is synchronous (blocking HTTP call to Postmark).
|
||||
# Run it in a thread to avoid blocking the async event loop.
|
||||
sender = EmailSender()
|
||||
await asyncio.to_thread(
|
||||
sender.send_template,
|
||||
user_email=user.email,
|
||||
subject=report.email_title or "Autopilot update",
|
||||
template_name=template_name,
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(report.email_body),
|
||||
"approval_summary_html": _markdown_to_email_html(report.approval_summary),
|
||||
"cta_url": cta_url,
|
||||
"cta_label": cta_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# --------------- email sweep --------------- #
|
||||
|
||||
|
||||
async def _process_pending_copilot_email_candidates(
|
||||
candidates: list,
|
||||
) -> PendingCopilotEmailSweepResult:
|
||||
result = PendingCopilotEmailSweepResult(candidate_count=len(candidates))
|
||||
|
||||
for candidate in candidates:
|
||||
session = await get_chat_session(candidate.session_id)
|
||||
if session is None or session.is_manual:
|
||||
continue
|
||||
|
||||
active = await stream_registry.get_session(session.session_id)
|
||||
is_running = active is not None and active.status == "running"
|
||||
if is_running:
|
||||
result.running_count += 1
|
||||
continue
|
||||
|
||||
pending_approval_count, graph_exec_id = await _get_pending_approval_metadata(
|
||||
session
|
||||
)
|
||||
|
||||
if session.completion_report is None:
|
||||
if session.completion_report_repair_count < MAX_COMPLETION_REPORT_REPAIRS:
|
||||
await _queue_completion_report_repair(
|
||||
session,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
result.repair_queued_count += 1
|
||||
continue
|
||||
|
||||
session.completed_at = session.completed_at or datetime.now(UTC)
|
||||
session.completion_report_repair_queued_at = None
|
||||
session.notification_email_skipped_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
result.skipped_count += 1
|
||||
continue
|
||||
|
||||
session.completed_at = session.completed_at or datetime.now(UTC)
|
||||
if (
|
||||
session.completion_report.pending_approval_graph_exec_id is None
|
||||
and graph_exec_id
|
||||
):
|
||||
session.completion_report = session.completion_report.model_copy(
|
||||
update={
|
||||
"has_pending_approvals": pending_approval_count > 0,
|
||||
"pending_approval_count": pending_approval_count,
|
||||
"pending_approval_graph_exec_id": graph_exec_id,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await _ensure_session_title_for_completed_session(session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to ensure session title for session %s",
|
||||
session.session_id,
|
||||
)
|
||||
|
||||
if not session.completion_report.should_notify_user:
|
||||
session.notification_email_skipped_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
result.skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
await _send_completion_email(session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to send nightly copilot email for session %s",
|
||||
session.session_id,
|
||||
)
|
||||
result.failed_count += 1
|
||||
continue
|
||||
|
||||
session.notification_email_sent_at = datetime.now(UTC)
|
||||
await upsert_chat_session(session)
|
||||
result.sent_count += 1
|
||||
|
||||
result.processed_count = result.sent_count + result.skipped_count
|
||||
return result
|
||||
|
||||
|
||||
async def _send_nightly_copilot_emails() -> int:
|
||||
candidates = await chat_db().get_pending_notification_chat_sessions(
|
||||
limit=PENDING_NOTIFICATION_SWEEP_LIMIT
|
||||
)
|
||||
result = await _process_pending_copilot_email_candidates(candidates)
|
||||
return result.processed_count
|
||||
|
||||
|
||||
async def send_nightly_copilot_emails() -> int:
|
||||
return await _send_nightly_copilot_emails()
|
||||
|
||||
|
||||
async def send_pending_copilot_emails_for_user(
|
||||
user_id: str,
|
||||
) -> PendingCopilotEmailSweepResult:
|
||||
candidates = await chat_db().get_pending_notification_chat_sessions_for_user(
|
||||
user_id,
|
||||
limit=PENDING_NOTIFICATION_SWEEP_LIMIT,
|
||||
)
|
||||
return await _process_pending_copilot_email_candidates(candidates)
|
||||
409
autogpt_platform/backend/backend/copilot/autopilot_prompts.py
Normal file
409
autogpt_platform/backend/backend/copilot/autopilot_prompts.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from backend.copilot.service import _get_system_prompt_template
|
||||
from backend.copilot.service import config as chat_config
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.data.db_accessors import chat_db, understanding_db
|
||||
from backend.data.understanding import format_understanding_for_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import InvitedUserRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
INTERNAL_TAG_RE = re.compile(r"<internal>.*?</internal>", re.DOTALL)
|
||||
MAX_COMPLETION_REPORT_REPAIRS = 2
|
||||
AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT = 6000
|
||||
AUTOPILOT_RECENT_SESSION_LIMIT = 5
|
||||
AUTOPILOT_RECENT_MESSAGE_LIMIT = 6
|
||||
AUTOPILOT_MESSAGE_CHAR_LIMIT = 500
|
||||
AUTOPILOT_EMAIL_HISTORY_LIMIT = 5
|
||||
AUTOPILOT_SESSION_SUMMARY_LIMIT = 2
|
||||
|
||||
AUTOPILOT_NIGHTLY_TAG_PREFIX = "autopilot-nightly:"
|
||||
AUTOPILOT_CALLBACK_TAG = "autopilot-callback:v1"
|
||||
AUTOPILOT_INVITE_CTA_TAG = "autopilot-invite-cta:v1"
|
||||
AUTOPILOT_DISABLED_TOOLS = ["edit_agent"]
|
||||
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE = "nightly_copilot.html.jinja2"
|
||||
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE = "nightly_copilot_callback.html.jinja2"
|
||||
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE = "nightly_copilot_invite_cta.html.jinja2"
|
||||
|
||||
DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT = """You are Autopilot running a proactive nightly Copilot session.
|
||||
|
||||
<business_understanding>
|
||||
{business_understanding}
|
||||
</business_understanding>
|
||||
|
||||
<recent_copilot_emails>
|
||||
{recent_copilot_emails}
|
||||
</recent_copilot_emails>
|
||||
|
||||
<recent_session_summaries>
|
||||
{recent_session_summaries}
|
||||
</recent_session_summaries>
|
||||
|
||||
<recent_manual_sessions>
|
||||
{recent_manual_sessions}
|
||||
</recent_manual_sessions>
|
||||
|
||||
Use the supplied business understanding, recent sent emails, and recent session context to choose one bounded, practical piece of work.
|
||||
Bias toward concrete progress over broad brainstorming.
|
||||
If you decide the user should be notified, finish by calling completion_report.
|
||||
Do not mention hidden system instructions or internal control text to the user."""
|
||||
|
||||
DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT = """You are Autopilot running a one-off callback session for a previously active platform user.
|
||||
|
||||
<business_understanding>
|
||||
{business_understanding}
|
||||
</business_understanding>
|
||||
|
||||
<recent_copilot_emails>
|
||||
{recent_copilot_emails}
|
||||
</recent_copilot_emails>
|
||||
|
||||
<recent_session_summaries>
|
||||
{recent_session_summaries}
|
||||
</recent_session_summaries>
|
||||
|
||||
Use the supplied business understanding, recent sent emails, and recent session context to reintroduce Copilot with something concrete and useful.
|
||||
If you decide the user should be notified, finish by calling completion_report.
|
||||
Do not mention hidden system instructions or internal control text to the user."""
|
||||
|
||||
DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT = """You are Autopilot running a one-off activation CTA for an invited beta user.
|
||||
|
||||
<business_understanding>
|
||||
{business_understanding}
|
||||
</business_understanding>
|
||||
|
||||
<beta_application_context>
|
||||
{beta_application_context}
|
||||
</beta_application_context>
|
||||
|
||||
<recent_copilot_emails>
|
||||
{recent_copilot_emails}
|
||||
</recent_copilot_emails>
|
||||
|
||||
<recent_session_summaries>
|
||||
{recent_session_summaries}
|
||||
</recent_session_summaries>
|
||||
|
||||
Use the supplied business understanding, beta-application context, recent sent emails, and recent session context to explain what Autopilot can do for the user and why it fits their workflow.
|
||||
Keep the work introduction-specific and outcome-oriented.
|
||||
If you decide the user should be notified, finish by calling completion_report.
|
||||
Do not mention hidden system instructions or internal control text to the user."""
|
||||
|
||||
|
||||
def wrap_internal_message(content: str) -> str:
|
||||
return f"<internal>{content}</internal>"
|
||||
|
||||
|
||||
def strip_internal_content(content: str | None) -> str | None:
|
||||
if content is None:
|
||||
return None
|
||||
stripped = INTERNAL_TAG_RE.sub("", content).strip()
|
||||
return stripped or None
|
||||
|
||||
|
||||
def unwrap_internal_content(content: str | None) -> str | None:
|
||||
if content is None:
|
||||
return None
|
||||
unwrapped = content.replace("<internal>", "").replace("</internal>", "").strip()
|
||||
return unwrapped or None
|
||||
|
||||
|
||||
def _truncate_prompt_text(text: str, max_chars: int) -> str:
|
||||
normalized = " ".join(text.split())
|
||||
if len(normalized) <= max_chars:
|
||||
return normalized
|
||||
return normalized[: max_chars - 3].rstrip() + "..."
|
||||
|
||||
|
||||
def _get_autopilot_prompt_name(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return chat_config.langfuse_autopilot_nightly_prompt_name
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return chat_config.langfuse_autopilot_callback_prompt_name
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return chat_config.langfuse_autopilot_invite_cta_prompt_name
|
||||
raise ValueError(f"Unsupported start type for autopilot prompt: {start_type}")
|
||||
|
||||
|
||||
def _get_autopilot_fallback_prompt(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT
|
||||
raise ValueError(f"Unsupported start type for autopilot prompt: {start_type}")
|
||||
|
||||
|
||||
def _format_start_type_label(start_type: ChatSessionStartType) -> str:
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return "Nightly"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return "Callback"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return "Beta Invite CTA"
|
||||
return start_type.value
|
||||
|
||||
|
||||
def _get_invited_user_tally_understanding(
|
||||
invited_user: InvitedUserRecord | None,
|
||||
) -> dict[str, Any] | None:
|
||||
return invited_user.tally_understanding if invited_user is not None else None
|
||||
|
||||
|
||||
def _render_initial_message(
|
||||
start_type: ChatSessionStartType,
|
||||
*,
|
||||
user_name: str | None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> str:
|
||||
display_name = user_name or "the user"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return wrap_internal_message(
|
||||
"This is a nightly proactive Copilot session. Review recent manual activity, "
|
||||
f"do one useful piece of work for {display_name}, and finish with completion_report."
|
||||
)
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return wrap_internal_message(
|
||||
"This is a one-off callback session for a previously active user. "
|
||||
f"Reintroduce Copilot with something concrete and useful for {display_name}, "
|
||||
"then finish with completion_report."
|
||||
)
|
||||
|
||||
invite_summary = ""
|
||||
tally_understanding = _get_invited_user_tally_understanding(invited_user)
|
||||
if tally_understanding is not None:
|
||||
invite_summary = "\nKnown context from the beta application:\n" + json.dumps(
|
||||
tally_understanding, ensure_ascii=False
|
||||
)
|
||||
return wrap_internal_message(
|
||||
"This is a one-off invite CTA session for an invited beta user who has not yet activated. "
|
||||
f"Create a tailored introduction for {display_name}, explain how Autopilot can help, "
|
||||
f"and finish with completion_report.{invite_summary}"
|
||||
)
|
||||
|
||||
|
||||
def _get_previous_local_midnight_utc(
|
||||
target_local_date: date,
|
||||
timezone_name: str,
|
||||
) -> datetime:
|
||||
from datetime import UTC
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
tz = ZoneInfo(timezone_name)
|
||||
previous_midnight_local = datetime.combine(
|
||||
target_local_date - timedelta(days=1),
|
||||
time.min,
|
||||
tzinfo=tz,
|
||||
)
|
||||
return previous_midnight_local.astimezone(UTC)
|
||||
|
||||
|
||||
async def _get_recent_manual_session_context(
|
||||
user_id: str,
|
||||
*,
|
||||
since_utc: datetime,
|
||||
) -> str:
|
||||
sessions = await chat_db().get_manual_chat_sessions_since(
|
||||
user_id,
|
||||
since_utc,
|
||||
AUTOPILOT_RECENT_SESSION_LIMIT,
|
||||
)
|
||||
|
||||
if not sessions:
|
||||
return "No recent manual sessions since the previous nightly run."
|
||||
|
||||
blocks: list[str] = []
|
||||
used_chars = 0
|
||||
|
||||
for session in sessions:
|
||||
messages = await chat_db().get_chat_messages_since(
|
||||
session.session_id, since_utc
|
||||
)
|
||||
|
||||
visible_messages: list[str] = []
|
||||
for message in messages[-AUTOPILOT_RECENT_MESSAGE_LIMIT:]:
|
||||
content = message.content or ""
|
||||
if message.role == "user":
|
||||
visible = strip_internal_content(content)
|
||||
else:
|
||||
visible = content.strip() or None
|
||||
if not visible:
|
||||
continue
|
||||
|
||||
role_label = {
|
||||
"user": "User",
|
||||
"assistant": "Assistant",
|
||||
"tool": "Tool",
|
||||
}.get(message.role, message.role.title())
|
||||
visible_messages.append(
|
||||
f"{role_label}: {_truncate_prompt_text(visible, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
|
||||
)
|
||||
|
||||
if not visible_messages:
|
||||
continue
|
||||
|
||||
title_suffix = f" ({session.title})" if session.title else ""
|
||||
block = (
|
||||
f"### Session updated {session.updated_at.isoformat()}{title_suffix}\n"
|
||||
+ "\n".join(visible_messages)
|
||||
)
|
||||
if used_chars + len(block) > AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT:
|
||||
break
|
||||
|
||||
blocks.append(block)
|
||||
used_chars += len(block)
|
||||
|
||||
return (
|
||||
"\n\n".join(blocks)
|
||||
if blocks
|
||||
else "No recent manual sessions since the previous nightly run."
|
||||
)
|
||||
|
||||
|
||||
async def _get_recent_sent_email_context(user_id: str) -> str:
|
||||
sessions = await chat_db().get_recent_sent_email_chat_sessions(
|
||||
user_id,
|
||||
AUTOPILOT_EMAIL_HISTORY_LIMIT,
|
||||
)
|
||||
if not sessions:
|
||||
return "No recent Copilot or Autopilot emails have been sent to this user."
|
||||
|
||||
blocks: list[str] = []
|
||||
for session in sessions:
|
||||
report = session.completion_report
|
||||
sent_at = session.notification_email_sent_at
|
||||
if report is None or sent_at is None:
|
||||
continue
|
||||
|
||||
lines = [
|
||||
f"### Sent {sent_at.isoformat()} ({_format_start_type_label(session.start_type)})",
|
||||
]
|
||||
if report.email_title:
|
||||
lines.append(
|
||||
f"Subject: {_truncate_prompt_text(report.email_title, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
|
||||
)
|
||||
if report.email_body:
|
||||
lines.append(
|
||||
f"Body: {_truncate_prompt_text(report.email_body, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
|
||||
)
|
||||
if report.callback_session_message:
|
||||
lines.append(
|
||||
"CTA Message: "
|
||||
+ _truncate_prompt_text(
|
||||
report.callback_session_message,
|
||||
AUTOPILOT_MESSAGE_CHAR_LIMIT,
|
||||
)
|
||||
)
|
||||
blocks.append("\n".join(lines))
|
||||
|
||||
return (
|
||||
"\n\n".join(blocks)
|
||||
if blocks
|
||||
else "No recent Copilot or Autopilot emails have been sent to this user."
|
||||
)
|
||||
|
||||
|
||||
async def _get_recent_session_summary_context(user_id: str) -> str:
|
||||
sessions = await chat_db().get_recent_completion_report_chat_sessions(
|
||||
user_id,
|
||||
AUTOPILOT_SESSION_SUMMARY_LIMIT,
|
||||
)
|
||||
if not sessions:
|
||||
return "No recent Copilot session summaries are available."
|
||||
|
||||
blocks: list[str] = []
|
||||
for session in sessions:
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
continue
|
||||
|
||||
title_suffix = f" ({session.title})" if session.title else ""
|
||||
lines = [
|
||||
f"### {_format_start_type_label(session.start_type)} session updated {session.updated_at.isoformat()}{title_suffix}",
|
||||
f"Summary: {_truncate_prompt_text(report.thoughts, AUTOPILOT_MESSAGE_CHAR_LIMIT)}",
|
||||
]
|
||||
if report.email_title:
|
||||
lines.append(
|
||||
"Email Title: "
|
||||
+ _truncate_prompt_text(
|
||||
report.email_title, AUTOPILOT_MESSAGE_CHAR_LIMIT
|
||||
)
|
||||
)
|
||||
blocks.append("\n".join(lines))
|
||||
|
||||
return (
|
||||
"\n\n".join(blocks)
|
||||
if blocks
|
||||
else "No recent Copilot session summaries are available."
|
||||
)
|
||||
|
||||
|
||||
async def _build_autopilot_system_prompt(
|
||||
user: Any,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
timezone_name: str,
|
||||
target_local_date: date | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> str:
|
||||
understanding = await understanding_db().get_business_understanding(user.id)
|
||||
business_understanding = (
|
||||
format_understanding_for_prompt(understanding)
|
||||
if understanding
|
||||
else "No saved business understanding yet."
|
||||
)
|
||||
recent_copilot_emails = await _get_recent_sent_email_context(user.id)
|
||||
recent_session_summaries = await _get_recent_session_summary_context(user.id)
|
||||
recent_manual_sessions = "Not applicable for this prompt type."
|
||||
beta_application_context = "No beta application context available."
|
||||
|
||||
users_information_sections = [
|
||||
"## Business Understanding\n" + business_understanding
|
||||
]
|
||||
users_information_sections.append(
|
||||
"## Recent Copilot Emails Sent To User\n" + recent_copilot_emails
|
||||
)
|
||||
users_information_sections.append(
|
||||
"## Recent Copilot Session Summaries\n" + recent_session_summaries
|
||||
)
|
||||
users_information = "\n\n".join(users_information_sections)
|
||||
|
||||
if (
|
||||
start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY
|
||||
and target_local_date is not None
|
||||
):
|
||||
recent_manual_sessions = await _get_recent_manual_session_context(
|
||||
user.id,
|
||||
since_utc=_get_previous_local_midnight_utc(
|
||||
target_local_date,
|
||||
timezone_name,
|
||||
),
|
||||
)
|
||||
|
||||
tally_understanding = _get_invited_user_tally_understanding(invited_user)
|
||||
if tally_understanding is not None:
|
||||
beta_application_context = json.dumps(tally_understanding, ensure_ascii=False)
|
||||
|
||||
return await _get_system_prompt_template(
|
||||
users_information,
|
||||
prompt_name=_get_autopilot_prompt_name(start_type),
|
||||
fallback_prompt=_get_autopilot_fallback_prompt(start_type),
|
||||
template_vars={
|
||||
"users_information": users_information,
|
||||
"business_understanding": business_understanding,
|
||||
"recent_copilot_emails": recent_copilot_emails,
|
||||
"recent_session_summaries": recent_session_summaries,
|
||||
"recent_manual_sessions": recent_manual_sessions,
|
||||
"beta_application_context": beta_application_context,
|
||||
},
|
||||
)
|
||||
1088
autogpt_platform/backend/backend/copilot/autopilot_test.py
Normal file
1088
autogpt_platform/backend/backend/copilot/autopilot_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -38,8 +38,8 @@ from backend.copilot.response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_resolve_system_prompt,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
@@ -160,7 +160,7 @@ async def stream_chat_completion_baseline(
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
if is_user_message and session.is_manual and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
@@ -177,16 +177,20 @@ async def stream_chat_completion_baseline(
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
base_system_prompt, _ = await _resolve_system_prompt(
|
||||
session,
|
||||
user_id,
|
||||
has_conversation_history=False,
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
base_system_prompt, _ = await _resolve_system_prompt(
|
||||
session,
|
||||
user_id=None,
|
||||
has_conversation_history=True,
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
system_prompt = base_system_prompt + get_baseline_supplement(session)
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
@@ -199,7 +203,7 @@ async def stream_chat_completion_baseline(
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
tools = get_available_tools(session)
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
|
||||
@@ -65,6 +65,18 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_autopilot_nightly_prompt_name: str = Field(
|
||||
default="CoPilot Nightly",
|
||||
description="Langfuse prompt name for nightly Autopilot sessions",
|
||||
)
|
||||
langfuse_autopilot_callback_prompt_name: str = Field(
|
||||
default="CoPilot Callback",
|
||||
description="Langfuse prompt name for callback Autopilot sessions",
|
||||
)
|
||||
langfuse_autopilot_invite_cta_prompt_name: str = Field(
|
||||
default="CoPilot Beta Invite CTA",
|
||||
description="Langfuse prompt name for beta invite CTA Autopilot sessions",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
|
||||
@@ -8,19 +8,48 @@ from typing import Any
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.models import ChatSessionCallbackToken as PrismaChatSessionCallbackToken
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
from .session_types import ChatSessionStartType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class ChatSessionCallbackTokenInfo(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
source_session_id: str | None = None
|
||||
callback_session_message: str
|
||||
expires_at: datetime
|
||||
consumed_at: datetime | None = None
|
||||
consumed_session_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(
|
||||
cls,
|
||||
token: PrismaChatSessionCallbackToken,
|
||||
) -> "ChatSessionCallbackTokenInfo":
|
||||
return cls(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
source_session_id=token.sourceSessionId,
|
||||
callback_session_message=token.callbackSessionMessage,
|
||||
expires_at=token.expiresAt,
|
||||
consumed_at=token.consumedAt,
|
||||
consumed_session_id=token.consumedSessionId,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
@@ -32,9 +61,103 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def has_recent_manual_message(user_id: str, since: datetime) -> bool:
|
||||
message = await PrismaChatMessage.prisma().find_first(
|
||||
where={
|
||||
"role": "user",
|
||||
"createdAt": {"gte": since},
|
||||
"Session": {
|
||||
"is": {
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
return message is not None
|
||||
|
||||
|
||||
async def has_session_since(user_id: str, since: datetime) -> bool:
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "createdAt": {"gte": since}}
|
||||
)
|
||||
return session is not None
|
||||
|
||||
|
||||
async def session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "executionTag": execution_tag}
|
||||
)
|
||||
return session is not None
|
||||
|
||||
|
||||
async def get_manual_chat_sessions_since(
|
||||
user_id: str,
|
||||
since_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
"updatedAt": {"gte": since_utc},
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_chat_messages_since(
|
||||
session_id: str,
|
||||
since_utc: datetime,
|
||||
) -> list[ChatMessage]:
|
||||
messages = await PrismaChatMessage.prisma().find_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"createdAt": {"gte": since_utc},
|
||||
},
|
||||
order={"sequence": "asc"},
|
||||
)
|
||||
return [ChatMessage.from_db(message) for message in messages]
|
||||
|
||||
|
||||
def _build_chat_message_create_input(
|
||||
*,
|
||||
session_id: str,
|
||||
sequence: int,
|
||||
now: datetime,
|
||||
msg: dict[str, Any],
|
||||
) -> ChatMessageCreateInput:
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": sequence,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: dict[str, Any] | None = None,
|
||||
) -> ChatSessionInfo:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
@@ -43,6 +166,9 @@ async def create_chat_session(
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
startType=start_type.value,
|
||||
executionTag=execution_tag,
|
||||
sessionConfig=SafeJson(session_config or {}),
|
||||
)
|
||||
prisma_session = await PrismaChatSession.prisma().create(data=data)
|
||||
return ChatSessionInfo.from_db(prisma_session)
|
||||
@@ -56,9 +182,19 @@ async def update_chat_session(
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
start_type: ChatSessionStartType | None = None,
|
||||
execution_tag: str | None | object = _UNSET,
|
||||
session_config: dict[str, Any] | None = None,
|
||||
completion_report: dict[str, Any] | None | object = _UNSET,
|
||||
completion_report_repair_count: int | None = None,
|
||||
completion_report_repair_queued_at: datetime | None | object = _UNSET,
|
||||
completed_at: datetime | None | object = _UNSET,
|
||||
notification_email_sent_at: datetime | None | object = _UNSET,
|
||||
notification_email_skipped_at: datetime | None | object = _UNSET,
|
||||
) -> ChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
should_clear_completion_report = completion_report is None
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
@@ -72,12 +208,41 @@ async def update_chat_session(
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
if start_type is not None:
|
||||
data["startType"] = start_type.value
|
||||
if execution_tag is not _UNSET:
|
||||
data["executionTag"] = execution_tag
|
||||
if session_config is not None:
|
||||
data["sessionConfig"] = SafeJson(session_config)
|
||||
if completion_report is not _UNSET and completion_report is not None:
|
||||
data["completionReport"] = SafeJson(completion_report)
|
||||
if completion_report_repair_count is not None:
|
||||
data["completionReportRepairCount"] = completion_report_repair_count
|
||||
if completion_report_repair_queued_at is not _UNSET:
|
||||
data["completionReportRepairQueuedAt"] = completion_report_repair_queued_at
|
||||
if completed_at is not _UNSET:
|
||||
data["completedAt"] = completed_at
|
||||
if notification_email_sent_at is not _UNSET:
|
||||
data["notificationEmailSentAt"] = notification_email_sent_at
|
||||
if notification_email_skipped_at is not _UNSET:
|
||||
data["notificationEmailSkippedAt"] = notification_email_skipped_at
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
)
|
||||
|
||||
if should_clear_completion_report:
|
||||
await db.execute_raw_with_schema(
|
||||
'UPDATE {schema_prefix}"ChatSession" SET "completionReport" = NULL WHERE "id" = $1',
|
||||
session_id,
|
||||
)
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
)
|
||||
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
@@ -187,37 +352,15 @@ async def add_chat_messages_batch(
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with db.transaction() as tx:
|
||||
# Build all message data
|
||||
messages_data = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
# Note: create_many doesn't support nested creates, use sessionId directly
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields — sanitize to strip
|
||||
# PostgreSQL-incompatible control characters.
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
messages_data.append(data)
|
||||
messages_data = [
|
||||
_build_chat_message_create_input(
|
||||
session_id=session_id,
|
||||
sequence=start_sequence + i,
|
||||
now=now,
|
||||
msg=msg,
|
||||
)
|
||||
for i, msg in enumerate(messages)
|
||||
]
|
||||
|
||||
# Run create_many and session update in parallel within transaction
|
||||
# Both use the same timestamp for consistency
|
||||
@@ -256,10 +399,14 @@ async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
with_auto: bool = False,
|
||||
) -> list[ChatSessionInfo]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
prisma_sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
where={
|
||||
"userId": user_id,
|
||||
**({} if with_auto else {"startType": ChatSessionStartType.MANUAL.value}),
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
@@ -267,9 +414,88 @@ async def get_user_chat_sessions(
|
||||
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
async def get_pending_notification_chat_sessions(
|
||||
limit: int = 200,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": None,
|
||||
"notificationEmailSkippedAt": None,
|
||||
},
|
||||
order={"updatedAt": "asc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_pending_notification_chat_sessions_for_user(
|
||||
user_id: str,
|
||||
limit: int = 200,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": None,
|
||||
"notificationEmailSkippedAt": None,
|
||||
},
|
||||
order={"updatedAt": "asc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_recent_sent_email_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": {"not": None},
|
||||
},
|
||||
order={"notificationEmailSentAt": "desc"},
|
||||
take=max(limit * 3, limit),
|
||||
)
|
||||
return [
|
||||
session_info
|
||||
for session_info in (ChatSessionInfo.from_db(session) for session in sessions)
|
||||
if session_info.notification_email_sent_at and session_info.completion_report
|
||||
][:limit]
|
||||
|
||||
|
||||
async def get_recent_completion_report_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=max(limit * 5, 10),
|
||||
)
|
||||
return [
|
||||
session_info
|
||||
for session_info in (ChatSessionInfo.from_db(session) for session in sessions)
|
||||
if session_info.completion_report is not None
|
||||
][:limit]
|
||||
|
||||
|
||||
async def get_user_session_count(
|
||||
user_id: str,
|
||||
with_auto: bool = False,
|
||||
) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
return await PrismaChatSession.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
**({} if with_auto else {"startType": ChatSessionStartType.MANUAL.value}),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
@@ -359,3 +585,42 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def create_chat_session_callback_token(
|
||||
user_id: str,
|
||||
source_session_id: str,
|
||||
callback_session_message: str,
|
||||
expires_at: datetime,
|
||||
) -> ChatSessionCallbackTokenInfo:
|
||||
token = await PrismaChatSessionCallbackToken.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"sourceSessionId": source_session_id,
|
||||
"callbackSessionMessage": callback_session_message,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
return ChatSessionCallbackTokenInfo.from_db(token)
|
||||
|
||||
|
||||
async def get_chat_session_callback_token(
|
||||
token_id: str,
|
||||
) -> ChatSessionCallbackTokenInfo | None:
|
||||
token = await PrismaChatSessionCallbackToken.prisma().find_unique(
|
||||
where={"id": token_id}
|
||||
)
|
||||
return ChatSessionCallbackTokenInfo.from_db(token) if token else None
|
||||
|
||||
|
||||
async def mark_chat_session_callback_token_consumed(
|
||||
token_id: str,
|
||||
consumed_session_id: str,
|
||||
) -> None:
|
||||
await PrismaChatSessionCallbackToken.prisma().update(
|
||||
where={"id": token_id},
|
||||
data={
|
||||
"consumedAt": datetime.now(UTC),
|
||||
"consumedSessionId": consumed_session_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -21,7 +21,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
@@ -29,6 +29,11 @@ from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from .config import ChatConfig
|
||||
from .session_types import (
|
||||
ChatSessionConfig,
|
||||
ChatSessionStartType,
|
||||
StoredCompletionReport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -80,11 +85,20 @@ class ChatSessionInfo(BaseModel):
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
credentials: dict[str, dict] = Field(default_factory=dict)
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
successful_agent_runs: dict[str, int] = Field(default_factory=dict)
|
||||
successful_agent_schedules: dict[str, int] = Field(default_factory=dict)
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL
|
||||
execution_tag: str | None = None
|
||||
session_config: ChatSessionConfig = Field(default_factory=ChatSessionConfig)
|
||||
completion_report: StoredCompletionReport | None = None
|
||||
completion_report_repair_count: int = 0
|
||||
completion_report_repair_queued_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
notification_email_sent_at: datetime | None = None
|
||||
notification_email_skipped_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
|
||||
@@ -97,6 +111,8 @@ class ChatSessionInfo(BaseModel):
|
||||
successful_agent_schedules = _parse_json_field(
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
session_config = _parse_json_field(prisma_session.sessionConfig, default={})
|
||||
completion_report = _parse_json_field(prisma_session.completionReport)
|
||||
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
@@ -110,6 +126,20 @@ class ChatSessionInfo(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
parsed_session_config = ChatSessionConfig.model_validate(session_config or {})
|
||||
parsed_completion_report = None
|
||||
if isinstance(completion_report, dict):
|
||||
try:
|
||||
parsed_completion_report = StoredCompletionReport.model_validate(
|
||||
completion_report
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Invalid completionReport payload on session %s",
|
||||
prisma_session.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return cls(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
@@ -120,6 +150,15 @@ class ChatSessionInfo(BaseModel):
|
||||
updated_at=prisma_session.updatedAt,
|
||||
successful_agent_runs=successful_agent_runs,
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
start_type=ChatSessionStartType(str(prisma_session.startType)),
|
||||
execution_tag=prisma_session.executionTag,
|
||||
session_config=parsed_session_config,
|
||||
completion_report=parsed_completion_report,
|
||||
completion_report_repair_count=prisma_session.completionReportRepairCount,
|
||||
completion_report_repair_queued_at=prisma_session.completionReportRepairQueuedAt,
|
||||
completed_at=prisma_session.completedAt,
|
||||
notification_email_sent_at=prisma_session.notificationEmailSentAt,
|
||||
notification_email_skipped_at=prisma_session.notificationEmailSkippedAt,
|
||||
)
|
||||
|
||||
|
||||
@@ -127,7 +166,13 @@ class ChatSession(ChatSessionInfo):
|
||||
messages: list[ChatMessage]
|
||||
|
||||
@classmethod
|
||||
def new(cls, user_id: str) -> Self:
|
||||
def new(
|
||||
cls,
|
||||
user_id: str,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: ChatSessionConfig | None = None,
|
||||
) -> Self:
|
||||
return cls(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -137,6 +182,9 @@ class ChatSession(ChatSessionInfo):
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
start_type=start_type,
|
||||
execution_tag=execution_tag,
|
||||
session_config=session_config or ChatSessionConfig(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -152,6 +200,16 @@ class ChatSession(ChatSessionInfo):
|
||||
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
|
||||
)
|
||||
|
||||
@property
|
||||
def is_manual(self) -> bool:
|
||||
return self.start_type == ChatSessionStartType.MANUAL
|
||||
|
||||
def allows_tool(self, tool_name: str) -> bool:
|
||||
return self.session_config.allows_tool(tool_name)
|
||||
|
||||
def disables_tool(self, tool_name: str) -> bool:
|
||||
return self.session_config.disables_tool(tool_name)
|
||||
|
||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||
"""Attach a tool_call to the current turn's assistant message.
|
||||
|
||||
@@ -524,6 +582,9 @@ async def _save_session_to_db(
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
session_config=session.session_config.model_dump(mode="json"),
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
@@ -539,6 +600,19 @@ async def _save_session_to_db(
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
session_config=session.session_config.model_dump(mode="json"),
|
||||
completion_report=(
|
||||
session.completion_report.model_dump(mode="json")
|
||||
if session.completion_report
|
||||
else None
|
||||
),
|
||||
completion_report_repair_count=session.completion_report_repair_count,
|
||||
completion_report_repair_queued_at=session.completion_report_repair_queued_at,
|
||||
completed_at=session.completed_at,
|
||||
notification_email_sent_at=session.notification_email_sent_at,
|
||||
notification_email_skipped_at=session.notification_email_skipped_at,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
@@ -601,7 +675,13 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
async def create_chat_session(
|
||||
user_id: str,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: ChatSessionConfig | None = None,
|
||||
initial_messages: list[ChatMessage] | None = None,
|
||||
) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
Raises:
|
||||
@@ -609,14 +689,30 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
callers never receive a non-persisted session that only exists
|
||||
in cache (which would be lost when the cache expires).
|
||||
"""
|
||||
session = ChatSession.new(user_id)
|
||||
session = ChatSession.new(
|
||||
user_id,
|
||||
start_type=start_type,
|
||||
execution_tag=execution_tag,
|
||||
session_config=session_config,
|
||||
)
|
||||
if initial_messages:
|
||||
session.messages.extend(initial_messages)
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db().create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
start_type=session.start_type,
|
||||
execution_tag=session.execution_tag,
|
||||
session_config=session.session_config.model_dump(mode="json"),
|
||||
)
|
||||
if session.messages:
|
||||
await _save_session_to_db(
|
||||
session,
|
||||
0,
|
||||
skip_existence_check=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session {session.session_id} in database: {e}")
|
||||
raise DatabaseError(
|
||||
@@ -636,6 +732,7 @@ async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
with_auto: bool = False,
|
||||
) -> tuple[list[ChatSessionInfo], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
@@ -644,8 +741,16 @@ async def get_user_sessions(
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
db = chat_db()
|
||||
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await db.get_user_session_count(user_id)
|
||||
sessions = await db.get_user_chat_sessions(
|
||||
user_id,
|
||||
limit,
|
||||
offset,
|
||||
with_auto=with_auto,
|
||||
)
|
||||
total_count = await db.get_user_session_count(
|
||||
user_id,
|
||||
with_auto=with_auto,
|
||||
)
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from .model import (
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .session_types import ChatSessionConfig, ChatSessionStartType
|
||||
|
||||
messages = [
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
@@ -46,7 +47,15 @@ messages = [
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s = ChatSession.new(
|
||||
user_id="abc123",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
execution_tag="autopilot-nightly:2026-03-13",
|
||||
session_config=ChatSessionConfig(
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=["edit_agent"],
|
||||
),
|
||||
)
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
|
||||
@@ -6,7 +6,7 @@ handling the distinction between:
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
@@ -161,7 +161,7 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
def _generate_tool_documentation(session=None) -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
@@ -177,11 +177,7 @@ def _generate_tool_documentation() -> str:
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
for name, tool in sorted(iter_available_tools(session), key=lambda item: item[0]):
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
@@ -209,7 +205,7 @@ def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
def get_baseline_supplement(session=None) -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
@@ -219,5 +215,5 @@ def get_baseline_supplement() -> str:
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
tool_docs = _generate_tool_documentation(session)
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
|
||||
@@ -12,7 +12,7 @@ import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import openai
|
||||
from claude_agent_sdk import (
|
||||
@@ -56,9 +56,9 @@ from ..response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
_resolve_system_prompt,
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
@@ -88,6 +88,10 @@ logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class _ClaudeSDKTransport(Protocol):
|
||||
async def write(self, data: str) -> None: ...
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
"""Configure OTEL tracing for the Claude Agent SDK → Langfuse.
|
||||
|
||||
@@ -137,6 +141,16 @@ def _setup_langfuse_otel() -> None:
|
||||
_setup_langfuse_otel()
|
||||
|
||||
|
||||
async def _write_multimodal_query(
|
||||
client: ClaudeSDKClient,
|
||||
user_message: dict[str, Any],
|
||||
) -> None:
|
||||
transport = cast(_ClaudeSDKTransport | None, getattr(client, "_transport", None))
|
||||
if transport is None:
|
||||
raise RuntimeError("Claude SDK transport is unavailable for multimodal input")
|
||||
await transport.write(json.dumps(user_message) + "\n")
|
||||
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
@@ -690,7 +704,7 @@ async def stream_chat_completion_sdk(
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
if is_user_message and session.is_manual and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
@@ -805,7 +819,11 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
e2b_sandbox, (base_system_prompt, _), dl = await asyncio.gather(
|
||||
_setup_e2b(),
|
||||
_build_system_prompt(user_id, has_conversation_history=has_history),
|
||||
_resolve_system_prompt(
|
||||
session,
|
||||
user_id,
|
||||
has_conversation_history=has_history,
|
||||
),
|
||||
_fetch_transcript(),
|
||||
)
|
||||
|
||||
@@ -862,7 +880,7 @@ async def stream_chat_completion_sdk(
|
||||
"Claude Code CLI subscription (requires `claude login`)."
|
||||
)
|
||||
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
mcp_server = create_copilot_mcp_server(session, use_e2b=use_e2b)
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
@@ -876,7 +894,7 @@ async def stream_chat_completion_sdk(
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
allowed = get_copilot_tool_names(session, use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
def _on_stderr(line: str) -> None:
|
||||
@@ -977,10 +995,7 @@ async def stream_chat_completion_sdk(
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": session_id,
|
||||
}
|
||||
assert client._transport is not None # noqa: SLF001
|
||||
await client._transport.write( # noqa: SLF001
|
||||
json.dumps(user_msg) + "\n"
|
||||
)
|
||||
await _write_multimodal_query(client, user_msg)
|
||||
# Capture user message in transcript (multimodal)
|
||||
transcript_builder.append_user(content=content_blocks)
|
||||
else:
|
||||
|
||||
@@ -205,6 +205,29 @@ class TestPromptSupplement:
|
||||
):
|
||||
assert "`browser_navigate`" in docs
|
||||
|
||||
def test_baseline_supplement_respects_session_disabled_tools(self):
|
||||
"""Session-specific docs should hide disabled tools and include added session tools."""
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.session_types import (
|
||||
ChatSessionConfig,
|
||||
ChatSessionStartType,
|
||||
)
|
||||
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
session_config=ChatSessionConfig(
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=["edit_agent"],
|
||||
),
|
||||
)
|
||||
|
||||
docs = get_baseline_supplement(session)
|
||||
|
||||
assert "`completion_report`" in docs
|
||||
assert "`edit_agent`" not in docs
|
||||
|
||||
def test_baseline_supplement_includes_workflows(self):
|
||||
"""Baseline supplement should include workflow guidance in tool descriptions."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
@@ -219,15 +242,13 @@ class TestPromptSupplement:
|
||||
def test_baseline_supplement_completeness(self):
|
||||
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Verify each available registered tool is documented
|
||||
# (matches _generate_tool_documentation which filters by is_available)
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
# (matches _generate_tool_documentation which filters with iter_available_tools)
|
||||
for tool_name, _ in iter_available_tools():
|
||||
assert (
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
@@ -277,14 +298,12 @@ class TestPromptSupplement:
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Count occurrences of each available tool in the entire supplement
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
for tool_name, _ in iter_available_tools():
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
@@ -32,7 +32,7 @@ from backend.copilot.sdk.file_ref import (
|
||||
expand_file_refs_in_args,
|
||||
read_file_bytes,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools import iter_available_tools
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
@@ -338,7 +338,11 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
def create_copilot_mcp_server(
|
||||
session: ChatSession,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
):
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
When *use_e2b* is True, five additional MCP file tools are registered
|
||||
@@ -387,7 +391,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
for tool_name, base_tool in iter_available_tools(session):
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
@@ -475,25 +479,30 @@ DANGEROUS_PATTERNS = [
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# Static tool name list for the non-E2B case (backward compatibility).
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
|
||||
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
|
||||
def get_copilot_tool_names(
|
||||
session: ChatSession,
|
||||
*,
|
||||
use_e2b: bool = False,
|
||||
) -> list[str]:
|
||||
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
|
||||
equivalents that route to the E2B sandbox.
|
||||
"""
|
||||
tool_names = [
|
||||
f"{MCP_TOOL_PREFIX}{name}" for name, _ in iter_available_tools(session)
|
||||
]
|
||||
|
||||
if not use_e2b:
|
||||
return list(COPILOT_TOOL_NAMES)
|
||||
return [
|
||||
*tool_names,
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
return [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
*tool_names,
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
|
||||
*_SDK_BUILTIN_ALWAYS,
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import ChatSessionConfig, ChatSessionStartType
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_copilot_tool_names,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
@@ -168,3 +171,20 @@ class TestTruncationAndStashIntegration:
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
|
||||
|
||||
class TestSessionToolFiltering:
|
||||
def test_disabled_tools_are_removed_from_sdk_allowed_tools(self):
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
session_config=ChatSessionConfig(
|
||||
extra_tools=["completion_report"],
|
||||
disabled_tools=["edit_agent"],
|
||||
),
|
||||
)
|
||||
|
||||
tool_names = get_copilot_tool_names(session)
|
||||
|
||||
assert "mcp__copilot__completion_report" in tool_names
|
||||
assert "mcp__copilot__edit_agent" not in tool_names
|
||||
|
||||
@@ -22,7 +22,7 @@ from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSessionInfo, get_chat_session, upsert_chat_session
|
||||
from .model import ChatSession, ChatSessionInfo, get_chat_session, upsert_chat_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,7 +64,13 @@ def _is_langfuse_configured() -> bool:
|
||||
)
|
||||
|
||||
|
||||
async def _get_system_prompt_template(context: str) -> str:
|
||||
async def _get_system_prompt_template(
|
||||
context: str,
|
||||
*,
|
||||
prompt_name: str | None = None,
|
||||
fallback_prompt: str | None = None,
|
||||
template_vars: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Get the system prompt, trying Langfuse first with fallback to default.
|
||||
|
||||
Args:
|
||||
@@ -73,6 +79,11 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
Returns:
|
||||
The compiled system prompt string.
|
||||
"""
|
||||
resolved_prompt_name = prompt_name or config.langfuse_prompt_name
|
||||
resolved_template_vars = {
|
||||
"users_information": context,
|
||||
**(template_vars or {}),
|
||||
}
|
||||
if _is_langfuse_configured():
|
||||
try:
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
@@ -85,16 +96,16 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
)
|
||||
prompt = await asyncio.to_thread(
|
||||
langfuse.get_prompt,
|
||||
config.langfuse_prompt_name,
|
||||
resolved_prompt_name,
|
||||
label=label,
|
||||
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
||||
)
|
||||
return prompt.compile(users_information=context)
|
||||
return prompt.compile(**resolved_template_vars)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
|
||||
|
||||
# Fallback to default prompt
|
||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||
return (fallback_prompt or DEFAULT_SYSTEM_PROMPT).format(**resolved_template_vars)
|
||||
|
||||
|
||||
async def _build_system_prompt(
|
||||
@@ -131,6 +142,21 @@ async def _build_system_prompt(
|
||||
return compiled, understanding
|
||||
|
||||
|
||||
async def _resolve_system_prompt(
|
||||
session: ChatSession,
|
||||
user_id: str | None,
|
||||
*,
|
||||
has_conversation_history: bool = False,
|
||||
) -> tuple[str, Any]:
|
||||
override = session.session_config.system_prompt_override
|
||||
if override:
|
||||
return override, None
|
||||
return await _build_system_prompt(
|
||||
user_id,
|
||||
has_conversation_history=has_conversation_history,
|
||||
)
|
||||
|
||||
|
||||
async def _generate_session_title(
|
||||
message: str,
|
||||
user_id: str | None = None,
|
||||
|
||||
60
autogpt_platform/backend/backend/copilot/session_types.py
Normal file
60
autogpt_platform/backend/backend/copilot/session_types.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class ChatSessionStartType(str, Enum):
|
||||
MANUAL = "MANUAL"
|
||||
AUTOPILOT_NIGHTLY = "AUTOPILOT_NIGHTLY"
|
||||
AUTOPILOT_CALLBACK = "AUTOPILOT_CALLBACK"
|
||||
AUTOPILOT_INVITE_CTA = "AUTOPILOT_INVITE_CTA"
|
||||
|
||||
|
||||
class ChatSessionConfig(BaseModel):
|
||||
system_prompt_override: str | None = None
|
||||
initial_user_message: str | None = None
|
||||
initial_assistant_message: str | None = None
|
||||
extra_tools: list[str] = Field(default_factory=list)
|
||||
disabled_tools: list[str] = Field(default_factory=list)
|
||||
|
||||
def allows_tool(self, tool_name: str) -> bool:
|
||||
return tool_name in self.extra_tools
|
||||
|
||||
def disables_tool(self, tool_name: str) -> bool:
|
||||
return tool_name in self.disabled_tools
|
||||
|
||||
|
||||
class CompletionReportInput(BaseModel):
|
||||
thoughts: str
|
||||
should_notify_user: bool
|
||||
email_title: str | None = None
|
||||
email_body: str | None = None
|
||||
callback_session_message: str | None = None
|
||||
approval_summary: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_notification_fields(self) -> "CompletionReportInput":
|
||||
if self.should_notify_user:
|
||||
required_fields = {
|
||||
"email_title": self.email_title,
|
||||
"email_body": self.email_body,
|
||||
"callback_session_message": self.callback_session_message,
|
||||
}
|
||||
missing = [
|
||||
field_name for field_name, value in required_fields.items() if not value
|
||||
]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Missing required notification fields: " + ", ".join(missing)
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class StoredCompletionReport(CompletionReportInput):
|
||||
has_pending_approvals: bool
|
||||
pending_approval_count: int
|
||||
pending_approval_graph_exec_id: str | None = None
|
||||
saved_at: datetime
|
||||
@@ -17,11 +17,12 @@ Subscribers:
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Awaitable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
@@ -55,6 +56,12 @@ _listener_sessions: dict[int, tuple[str, asyncio.Task]] = {}
|
||||
# Timeout for putting chunks into subscriber queues (seconds)
|
||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||
QUEUE_PUT_TIMEOUT = 5.0
|
||||
SESSION_LOOKUP_RETRY_SECONDS = 0.05
|
||||
STREAM_REPLAY_COUNT = 1000
|
||||
STREAM_XREAD_BLOCK_MS = 5000
|
||||
STREAM_XREAD_COUNT = 100
|
||||
STALE_SESSION_BUFFER_SECONDS = 300
|
||||
UNSUBSCRIBE_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||
# Returns 1 if status was updated, 0 if already completed/failed
|
||||
@@ -68,19 +75,24 @@ return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveSession:
|
||||
SessionStatus = Literal["running", "completed", "failed"]
|
||||
RedisHash = dict[str, str]
|
||||
RedisStreamMessages = list[tuple[str, list[tuple[str, RedisHash]]]]
|
||||
|
||||
|
||||
class ActiveSession(BaseModel):
|
||||
"""Represents an active streaming session (metadata only, no in-memory queues)."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
turn_id: str = ""
|
||||
blocking: bool = False # If True, HTTP request is waiting for completion
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
status: SessionStatus = "running"
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
def _get_session_meta_key(session_id: str) -> str:
|
||||
@@ -93,7 +105,54 @@ def _get_turn_stream_key(turn_id: str) -> str:
|
||||
return f"{config.turn_stream_prefix}{turn_id}"
|
||||
|
||||
|
||||
def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSession:
|
||||
async def _redis_hset_mapping(redis: Any, key: str, mapping: RedisHash) -> int:
|
||||
return await cast(Awaitable[int], redis.hset(key, mapping=mapping))
|
||||
|
||||
|
||||
async def _redis_hgetall(redis: Any, key: str) -> RedisHash:
|
||||
return cast(
|
||||
RedisHash,
|
||||
await cast(Awaitable[dict[str, str]], redis.hgetall(key)),
|
||||
)
|
||||
|
||||
|
||||
async def _redis_hget(redis: Any, key: str, field: str) -> str | None:
|
||||
return cast(
|
||||
str | None,
|
||||
await cast(Awaitable[str | None], redis.hget(key, field)),
|
||||
)
|
||||
|
||||
|
||||
async def _redis_xread(
|
||||
redis: Any,
|
||||
streams: dict[str, str],
|
||||
*,
|
||||
count: int,
|
||||
block: int | None,
|
||||
) -> RedisStreamMessages:
|
||||
return cast(
|
||||
RedisStreamMessages,
|
||||
await cast(
|
||||
Awaitable[RedisStreamMessages],
|
||||
redis.xread(streams, count=count, block=block),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _redis_complete_session(
|
||||
redis: Any,
|
||||
meta_key: str,
|
||||
status: SessionStatus,
|
||||
) -> int:
|
||||
return int(
|
||||
await cast(
|
||||
Awaitable[int | str],
|
||||
redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _parse_session_meta(meta: RedisHash, session_id: str = "") -> ActiveSession:
|
||||
"""Parse a raw Redis hash into a typed ActiveSession.
|
||||
|
||||
Centralises the ``meta.get(...)`` boilerplate so callers don't repeat it.
|
||||
@@ -107,7 +166,7 @@ def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSes
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
turn_id=meta.get("turn_id", "") or session_id,
|
||||
blocking=meta.get("blocking") == "1",
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
status=cast(SessionStatus, meta.get("status", "running")),
|
||||
)
|
||||
|
||||
|
||||
@@ -170,7 +229,8 @@ async def create_session(
|
||||
# No need to delete old stream — each turn_id is a fresh UUID
|
||||
|
||||
hset_start = time.perf_counter()
|
||||
await redis.hset( # type: ignore[misc]
|
||||
await _redis_hset_mapping(
|
||||
redis,
|
||||
meta_key,
|
||||
mapping={
|
||||
"session_id": session_id,
|
||||
@@ -280,6 +340,108 @@ async def publish_chunk(
|
||||
return message_id
|
||||
|
||||
|
||||
def _decode_stream_chunk(msg_data: RedisHash) -> StreamBaseResponse | None:
|
||||
raw_data = msg_data.get("data")
|
||||
if raw_data is None:
|
||||
return None
|
||||
|
||||
chunk_data = orjson.loads(raw_data)
|
||||
return _reconstruct_chunk(chunk_data)
|
||||
|
||||
|
||||
async def _replay_messages(
|
||||
messages: RedisStreamMessages,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
*,
|
||||
last_message_id: str,
|
||||
) -> tuple[int, str]:
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id
|
||||
try:
|
||||
chunk = _decode_stream_chunk(msg_data)
|
||||
if chunk is None:
|
||||
continue
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to replay message: %s", exc)
|
||||
|
||||
return replayed_count, replay_last_id
|
||||
|
||||
|
||||
async def _deliver_message_to_queue(
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
chunk: StreamBaseResponse,
|
||||
*,
|
||||
last_delivered_id: str,
|
||||
log_meta: dict[str, Any],
|
||||
) -> bool:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||
"reason": "queue_full",
|
||||
}
|
||||
},
|
||||
)
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for session {session_id}, queue completely blocked"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_xread_timeout(
|
||||
redis: Any,
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> bool:
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
status = await _redis_hget(redis, meta_key, "status")
|
||||
if status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout delivering finish event for session {session_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamHeartbeat()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Timeout delivering heartbeat for session {session_id}")
|
||||
return True
|
||||
|
||||
|
||||
async def subscribe_to_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
@@ -313,7 +475,7 @@ async def subscribe_to_session(
|
||||
redis_start = time.perf_counter()
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||
@@ -328,8 +490,8 @@ async def subscribe_to_session(
|
||||
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
await asyncio.sleep(SESSION_LOOKUP_RETRY_SECONDS)
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
if not meta:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
@@ -374,7 +536,12 @@ async def subscribe_to_session(
|
||||
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
xread_start = time.perf_counter()
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000)
|
||||
messages = await _redis_xread(
|
||||
redis,
|
||||
{stream_key: last_message_id},
|
||||
block=None,
|
||||
count=STREAM_REPLAY_COUNT,
|
||||
)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}",
|
||||
@@ -387,22 +554,11 @@ async def subscribe_to_session(
|
||||
},
|
||||
)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
if "data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
replayed_count, replay_last_id = await _replay_messages(
|
||||
messages,
|
||||
subscriber_queue,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||
@@ -455,7 +611,7 @@ async def _stream_listener(
|
||||
session_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
last_replayed_id: str,
|
||||
log_meta: dict | None = None,
|
||||
log_meta: dict[str, Any] | None = None,
|
||||
turn_id: str = "",
|
||||
) -> None:
|
||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||
@@ -499,8 +655,11 @@ async def _stream_listener(
|
||||
# Short timeout prevents frontend timeout (12s) while waiting for heartbeats (15s)
|
||||
xread_start = time.perf_counter()
|
||||
xread_count += 1
|
||||
messages = await redis.xread(
|
||||
{stream_key: current_id}, block=5000, count=100
|
||||
messages = await _redis_xread(
|
||||
redis,
|
||||
{stream_key: current_id},
|
||||
block=STREAM_XREAD_BLOCK_MS,
|
||||
count=STREAM_XREAD_COUNT,
|
||||
)
|
||||
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||
|
||||
@@ -532,114 +691,66 @@ async def _stream_listener(
|
||||
)
|
||||
|
||||
if not messages:
|
||||
# Timeout - check if session is still running
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||
# Stop if session metadata is gone (TTL expired) or status is not "running"
|
||||
if status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering finish event for session {session_id}"
|
||||
)
|
||||
if not await _handle_xread_timeout(
|
||||
redis,
|
||||
session_id,
|
||||
subscriber_queue,
|
||||
):
|
||||
break
|
||||
# Session still running - send heartbeat to keep connection alive
|
||||
# This prevents frontend timeout (12s) during long-running operations
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamHeartbeat()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering heartbeat for session {session_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
|
||||
if "data" not in msg_data:
|
||||
continue
|
||||
|
||||
current_id = msg_id
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
# Update last delivered ID on successful delivery
|
||||
last_delivered_id = current_id
|
||||
messages_delivered += 1
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||
"reason": "queue_full",
|
||||
}
|
||||
},
|
||||
)
|
||||
# Send overflow error with recovery info
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
# Queue is completely stuck, nothing more we can do
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for session {session_id}, "
|
||||
"queue completely blocked"
|
||||
)
|
||||
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
chunk = _decode_stream_chunk(msg_data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error processing stream message: {e}",
|
||||
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||
)
|
||||
continue
|
||||
|
||||
if chunk is None:
|
||||
continue
|
||||
|
||||
delivered = await _deliver_message_to_queue(
|
||||
session_id,
|
||||
subscriber_queue,
|
||||
chunk,
|
||||
last_delivered_id=last_delivered_id,
|
||||
log_meta=log_meta,
|
||||
)
|
||||
if delivered:
|
||||
last_delivered_id = current_id
|
||||
messages_delivered += 1
|
||||
if first_message_time is None:
|
||||
first_message_time = time.perf_counter()
|
||||
elapsed = (first_message_time - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"messages_delivered": messages_delivered,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
@@ -712,16 +823,16 @@ async def mark_session_completed(
|
||||
Returns:
|
||||
True if session was newly marked completed, False if already completed/failed
|
||||
"""
|
||||
status: Literal["completed", "failed"] = "failed" if error_message else "completed"
|
||||
status: SessionStatus = "failed" if error_message else "completed"
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
|
||||
# Resolve turn_id for publishing to the correct stream
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
result = await _redis_complete_session(redis, meta_key, status)
|
||||
|
||||
if result == 0:
|
||||
logger.debug(f"Session {session_id} already completed/failed, skipping")
|
||||
@@ -774,6 +885,18 @@ async def mark_session_completed(
|
||||
f"for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.copilot.autopilot import handle_non_manual_session_completion
|
||||
|
||||
await handle_non_manual_session_completion(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process non-manual completion for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -788,7 +911,7 @@ async def get_session(session_id: str) -> ActiveSession | None:
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
|
||||
if not meta:
|
||||
return None
|
||||
@@ -815,7 +938,7 @@ async def get_session_with_expiry_info(
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
|
||||
if not meta:
|
||||
# Metadata expired — we can't resolve turn_id, so check using
|
||||
@@ -847,7 +970,7 @@ async def get_active_session(
|
||||
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
meta = await _redis_hgetall(redis, meta_key)
|
||||
|
||||
if not meta:
|
||||
return None, "0-0"
|
||||
@@ -871,7 +994,9 @@ async def get_active_session(
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str)
|
||||
age_seconds = (datetime.now(timezone.utc) - created_at).total_seconds()
|
||||
stale_threshold = COPILOT_CONSUMER_TIMEOUT_SECONDS + 300 # + 5min buffer
|
||||
stale_threshold = (
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS + STALE_SESSION_BUFFER_SECONDS
|
||||
)
|
||||
if age_seconds > stale_threshold:
|
||||
logger.warning(
|
||||
f"[STALE_SESSION] Auto-completing stale session {session_id[:8]}... "
|
||||
@@ -946,7 +1071,11 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
}
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
||||
if not isinstance(chunk_type, str):
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
|
||||
chunk_class = type_to_class.get(chunk_type)
|
||||
|
||||
if chunk_class is None:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
@@ -1011,7 +1140,7 @@ async def unsubscribe_from_session(
|
||||
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(listener_task, timeout=5.0)
|
||||
await asyncio.wait_for(listener_task, timeout=UNSUBSCRIBE_TIMEOUT_SECONDS)
|
||||
except asyncio.CancelledError:
|
||||
# Expected - the task was successfully cancelled
|
||||
pass
|
||||
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .completion_report import CompletionReportTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -50,10 +51,12 @@ if TYPE_CHECKING:
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SESSION_SCOPED_TOOL_NAMES = {"completion_report"}
|
||||
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"completion_report": CompletionReportTool(),
|
||||
"create_agent": CreateAgentTool(),
|
||||
"customize_agent": CustomizeAgentTool(),
|
||||
"edit_agent": EditAgentTool(),
|
||||
@@ -103,16 +106,38 @@ find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
|
||||
def get_available_tools() -> list[ChatCompletionToolParam]:
|
||||
def is_tool_enabled(tool_name: str, session: "ChatSession | None" = None) -> bool:
|
||||
if tool_name not in TOOL_REGISTRY:
|
||||
return False
|
||||
if session is not None and session.disables_tool(tool_name):
|
||||
return False
|
||||
if tool_name not in SESSION_SCOPED_TOOL_NAMES:
|
||||
return True
|
||||
if session is None:
|
||||
return False
|
||||
return session.allows_tool(tool_name)
|
||||
|
||||
|
||||
def iter_available_tools(
|
||||
session: "ChatSession | None" = None,
|
||||
) -> list[tuple[str, BaseTool]]:
|
||||
return [
|
||||
(tool_name, tool)
|
||||
for tool_name, tool in TOOL_REGISTRY.items()
|
||||
if tool.is_available and is_tool_enabled(tool_name, session)
|
||||
]
|
||||
|
||||
|
||||
def get_available_tools(
|
||||
session: "ChatSession | None" = None,
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
"""Return OpenAI tool schemas for tools available in the current environment.
|
||||
|
||||
Called per-request so that env-var or binary availability is evaluated
|
||||
fresh each time (e.g. browser_* tools are excluded when agent-browser
|
||||
CLI is not installed).
|
||||
"""
|
||||
return [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
|
||||
]
|
||||
return [tool.as_openai_tool() for _, tool in iter_available_tools(session)]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
@@ -128,6 +153,9 @@ async def execute_tool(
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
if not is_tool_enabled(tool_name, session):
|
||||
raise ValueError(f"Tool {tool_name} is not enabled for this session")
|
||||
|
||||
tool = get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Tool for finalizing non-manual Copilot sessions."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.constants import COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import CompletionReportInput
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import CompletionReportSavedResponse, ErrorResponse, ToolResponseBase
|
||||
|
||||
|
||||
class CompletionReportTool(BaseTool):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "completion_report"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Finalize a non-manual session after you have finished the work. "
|
||||
"Use this exactly once at the end of the flow. "
|
||||
"Summarize what you did, state whether the user should be notified, "
|
||||
"and provide any email/callback content that should be used."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
schema = CompletionReportInput.model_json_schema()
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": schema.get("properties", {}),
|
||||
"required": [
|
||||
"thoughts",
|
||||
"should_notify_user",
|
||||
"email_title",
|
||||
"email_body",
|
||||
"callback_session_message",
|
||||
"approval_summary",
|
||||
],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
if session.is_manual:
|
||||
return ErrorResponse(
|
||||
message="completion_report is only available in non-manual sessions.",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
report = CompletionReportInput.model_validate(kwargs)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message="completion_report arguments are invalid.",
|
||||
error=str(exc),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
pending_approval_count = await review_db().count_pending_reviews_for_graph_exec(
|
||||
f"{COPILOT_SESSION_PREFIX}{session.session_id}",
|
||||
session.user_id,
|
||||
)
|
||||
|
||||
if pending_approval_count > 0 and not report.approval_summary:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"approval_summary is required because this session has pending approvals."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
return CompletionReportSavedResponse(
|
||||
message="Completion report recorded successfully.",
|
||||
session_id=session.session_id,
|
||||
has_pending_approvals=pending_approval_count > 0,
|
||||
pending_approval_count=pending_approval_count,
|
||||
)
|
||||
@@ -0,0 +1,95 @@
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import ChatSessionStartType
|
||||
from backend.copilot.tools.completion_report import CompletionReportTool
|
||||
from backend.copilot.tools.models import CompletionReportSavedResponse, ResponseType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_report_rejects_manual_sessions() -> None:
|
||||
tool = CompletionReportTool()
|
||||
session = ChatSession.new("user-1")
|
||||
|
||||
response = await tool._execute(
|
||||
user_id="user-1",
|
||||
session=session,
|
||||
thoughts="Wrapped up the session.",
|
||||
should_notify_user=False,
|
||||
email_title=None,
|
||||
email_body=None,
|
||||
callback_session_message=None,
|
||||
approval_summary=None,
|
||||
)
|
||||
|
||||
assert response.type == ResponseType.ERROR
|
||||
assert "non-manual sessions" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_report_requires_approval_summary_when_pending(
|
||||
mocker,
|
||||
) -> None:
|
||||
tool = CompletionReportTool()
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
)
|
||||
|
||||
review_store = Mock()
|
||||
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=2)
|
||||
mocker.patch(
|
||||
"backend.copilot.tools.completion_report.review_db",
|
||||
return_value=review_store,
|
||||
)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id="user-1",
|
||||
session=session,
|
||||
thoughts="Prepared a recommendation for the user.",
|
||||
should_notify_user=True,
|
||||
email_title="Your nightly update",
|
||||
email_body="I found something worth reviewing.",
|
||||
callback_session_message="Let's review the next step together.",
|
||||
approval_summary=None,
|
||||
)
|
||||
|
||||
assert response.type == ResponseType.ERROR
|
||||
assert "approval_summary is required" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_report_succeeds_without_pending_approvals(
|
||||
mocker,
|
||||
) -> None:
|
||||
tool = CompletionReportTool()
|
||||
session = ChatSession.new(
|
||||
"user-1",
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
)
|
||||
|
||||
review_store = Mock()
|
||||
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=0)
|
||||
mocker.patch(
|
||||
"backend.copilot.tools.completion_report.review_db",
|
||||
return_value=review_store,
|
||||
)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id="user-1",
|
||||
session=session,
|
||||
thoughts="Reviewed the account and prepared a useful follow-up.",
|
||||
should_notify_user=True,
|
||||
email_title="Autopilot found something useful",
|
||||
email_body="I put together a recommendation for you.",
|
||||
callback_session_message="Open this chat and I will walk you through it.",
|
||||
approval_summary=None,
|
||||
)
|
||||
|
||||
assert response.type == ResponseType.COMPLETION_REPORT_SAVED
|
||||
response = cast(CompletionReportSavedResponse, response)
|
||||
assert response.has_pending_approvals is False
|
||||
assert response.pending_approval_count == 0
|
||||
@@ -16,6 +16,7 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
NEED_LOGIN = "need_login"
|
||||
COMPLETION_REPORT_SAVED = "completion_report_saved"
|
||||
|
||||
# Agent discovery & execution
|
||||
AGENTS_FOUND = "agents_found"
|
||||
@@ -138,7 +139,7 @@ class NoResultsResponse(ToolResponseBase):
|
||||
"""Response when no agents found."""
|
||||
|
||||
type: ResponseType = ResponseType.NO_RESULTS
|
||||
suggestions: list[str] = []
|
||||
suggestions: list[str] = Field(default_factory=list)
|
||||
name: str = "no_results"
|
||||
|
||||
|
||||
@@ -170,8 +171,8 @@ class AgentDetails(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
in_library: bool = False
|
||||
inputs: dict[str, Any] = {}
|
||||
credentials: list[CredentialsMetaInput] = []
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
credentials: list[CredentialsMetaInput] = Field(default_factory=list)
|
||||
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
|
||||
trigger_info: dict[str, Any] | None = None
|
||||
|
||||
@@ -191,7 +192,7 @@ class UserReadiness(BaseModel):
|
||||
"""User readiness status."""
|
||||
|
||||
has_all_credentials: bool = False
|
||||
missing_credentials: dict[str, Any] = {}
|
||||
missing_credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
ready_to_run: bool = False
|
||||
|
||||
|
||||
@@ -248,6 +249,14 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CompletionReportSavedResponse(ToolResponseBase):
|
||||
"""Response for completion_report."""
|
||||
|
||||
type: ResponseType = ResponseType.COMPLETION_REPORT_SAVED
|
||||
has_pending_approvals: bool = False
|
||||
pending_approval_count: int = 0
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
@@ -436,9 +445,9 @@ class BlockDetails(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
inputs: dict[str, Any] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
credentials: list[CredentialsMetaInput] = []
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
credentials: list[CredentialsMetaInput] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BlockDetailsResponse(ToolResponseBase):
|
||||
@@ -631,7 +640,7 @@ class FolderInfo(BaseModel):
|
||||
class FolderTreeInfo(FolderInfo):
|
||||
"""Folder with nested children for tree display."""
|
||||
|
||||
children: list["FolderTreeInfo"] = []
|
||||
children: list["FolderTreeInfo"] = Field(default_factory=list)
|
||||
|
||||
|
||||
class FolderCreatedResponse(ToolResponseBase):
|
||||
@@ -678,6 +687,6 @@ class AgentsMovedToFolderResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
|
||||
agent_ids: list[str]
|
||||
agent_names: list[str] = []
|
||||
agent_names: list[str] = Field(default_factory=list)
|
||||
folder_id: str | None = None
|
||||
count: int = 0
|
||||
|
||||
@@ -92,6 +92,19 @@ def user_db():
|
||||
return user_db
|
||||
|
||||
|
||||
def invited_user_db():
|
||||
if db.is_connected():
|
||||
from backend.data import invited_user as _invited_user_db
|
||||
|
||||
invited_user_db = _invited_user_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
invited_user_db = get_database_manager_async_client()
|
||||
|
||||
return invited_user_db
|
||||
|
||||
|
||||
def understanding_db():
|
||||
if db.is_connected():
|
||||
from backend.data import understanding as _understanding_db
|
||||
|
||||
@@ -79,6 +79,7 @@ from backend.data.graph import (
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
count_pending_reviews_for_graph_exec,
|
||||
delete_review_by_node_exec_id,
|
||||
get_or_create_human_review,
|
||||
get_pending_reviews_for_execution,
|
||||
@@ -86,6 +87,7 @@ from backend.data.human_review import (
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
)
|
||||
from backend.data.invited_user import list_invited_users_for_auth_users
|
||||
from backend.data.notifications import (
|
||||
clear_all_user_notification_batches,
|
||||
create_or_add_to_user_notification_batch,
|
||||
@@ -107,6 +109,7 @@ from backend.data.user import (
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
get_user_notification_preference,
|
||||
list_users,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.data.workspace import (
|
||||
@@ -115,6 +118,7 @@ from backend.data.workspace import (
|
||||
get_or_create_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_file_by_path,
|
||||
get_workspace_files_by_ids,
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
@@ -237,6 +241,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
list_users = _(list_users)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
@@ -249,6 +254,7 @@ class DatabaseManager(AppService):
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
count_pending_reviews_for_graph_exec = _(count_pending_reviews_for_graph_exec)
|
||||
delete_review_by_node_exec_id = _(delete_review_by_node_exec_id)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
get_pending_reviews_for_execution = _(get_pending_reviews_for_execution)
|
||||
@@ -313,12 +319,16 @@ class DatabaseManager(AppService):
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = _(count_workspace_files)
|
||||
create_workspace_file = _(create_workspace_file)
|
||||
get_workspace_files_by_ids = _(get_workspace_files_by_ids)
|
||||
get_or_create_workspace = _(get_or_create_workspace)
|
||||
get_workspace_file = _(get_workspace_file)
|
||||
get_workspace_file_by_path = _(get_workspace_file_by_path)
|
||||
list_workspace_files = _(list_workspace_files)
|
||||
soft_delete_workspace_file = _(soft_delete_workspace_file)
|
||||
|
||||
# ============ Invited Users ============ #
|
||||
list_invited_users_for_auth_users = _(list_invited_users_for_auth_users)
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
@@ -328,8 +338,28 @@ class DatabaseManager(AppService):
|
||||
update_block_optimized_description = _(update_block_optimized_description)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_messages_since = _(chat_db.get_chat_messages_since)
|
||||
get_chat_session_callback_token = _(chat_db.get_chat_session_callback_token)
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session_callback_token = _(chat_db.create_chat_session_callback_token)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
get_manual_chat_sessions_since = _(chat_db.get_manual_chat_sessions_since)
|
||||
get_pending_notification_chat_sessions = _(
|
||||
chat_db.get_pending_notification_chat_sessions
|
||||
)
|
||||
get_pending_notification_chat_sessions_for_user = _(
|
||||
chat_db.get_pending_notification_chat_sessions_for_user
|
||||
)
|
||||
get_recent_completion_report_chat_sessions = _(
|
||||
chat_db.get_recent_completion_report_chat_sessions
|
||||
)
|
||||
get_recent_sent_email_chat_sessions = _(chat_db.get_recent_sent_email_chat_sessions)
|
||||
has_recent_manual_message = _(chat_db.has_recent_manual_message)
|
||||
has_session_since = _(chat_db.has_session_since)
|
||||
mark_chat_session_callback_token_consumed = _(
|
||||
chat_db.mark_chat_session_callback_token_consumed
|
||||
)
|
||||
session_exists_for_execution_tag = _(chat_db.session_exists_for_execution_tag)
|
||||
update_chat_session = _(chat_db.update_chat_session)
|
||||
add_chat_message = _(chat_db.add_chat_message)
|
||||
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
|
||||
@@ -374,10 +404,18 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_marketplace_graphs_for_monitoring = _(d.get_marketplace_graphs_for_monitoring)
|
||||
|
||||
# Human In The Loop
|
||||
count_pending_reviews_for_graph_exec = _(d.count_pending_reviews_for_graph_exec)
|
||||
has_pending_reviews_for_graph_exec = _(d.has_pending_reviews_for_graph_exec)
|
||||
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
list_users = _(d.list_users)
|
||||
|
||||
# CoPilot Chat Sessions
|
||||
get_recent_completion_report_chat_sessions = _(
|
||||
d.get_recent_completion_report_chat_sessions
|
||||
)
|
||||
get_recent_sent_email_chat_sessions = _(d.get_recent_sent_email_chat_sessions)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
@@ -433,12 +471,14 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = d.get_user_by_id
|
||||
list_users = d.list_users
|
||||
get_user_integrations = d.get_user_integrations
|
||||
update_user_integrations = d.update_user_integrations
|
||||
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
count_pending_reviews_for_graph_exec = d.count_pending_reviews_for_graph_exec
|
||||
delete_review_by_node_exec_id = d.delete_review_by_node_exec_id
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
get_pending_reviews_for_execution = d.get_pending_reviews_for_execution
|
||||
@@ -506,12 +546,16 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = d.count_workspace_files
|
||||
create_workspace_file = d.create_workspace_file
|
||||
get_workspace_files_by_ids = d.get_workspace_files_by_ids
|
||||
get_or_create_workspace = d.get_or_create_workspace
|
||||
get_workspace_file = d.get_workspace_file
|
||||
get_workspace_file_by_path = d.get_workspace_file_by_path
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Invited Users ============ #
|
||||
list_invited_users_for_auth_users = d.list_invited_users_for_auth_users
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
@@ -520,8 +564,26 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_blocks_needing_optimization = d.get_blocks_needing_optimization
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_messages_since = d.get_chat_messages_since
|
||||
get_chat_session_callback_token = d.get_chat_session_callback_token
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session_callback_token = d.create_chat_session_callback_token
|
||||
create_chat_session = d.create_chat_session
|
||||
get_manual_chat_sessions_since = d.get_manual_chat_sessions_since
|
||||
get_pending_notification_chat_sessions = d.get_pending_notification_chat_sessions
|
||||
get_pending_notification_chat_sessions_for_user = (
|
||||
d.get_pending_notification_chat_sessions_for_user
|
||||
)
|
||||
get_recent_completion_report_chat_sessions = (
|
||||
d.get_recent_completion_report_chat_sessions
|
||||
)
|
||||
get_recent_sent_email_chat_sessions = d.get_recent_sent_email_chat_sessions
|
||||
has_recent_manual_message = d.has_recent_manual_message
|
||||
has_session_since = d.has_session_since
|
||||
mark_chat_session_callback_token_consumed = (
|
||||
d.mark_chat_session_callback_token_consumed
|
||||
)
|
||||
session_exists_for_execution_tag = d.session_exists_for_execution_tag
|
||||
update_chat_session = d.update_chat_session
|
||||
add_chat_message = d.add_chat_message
|
||||
add_chat_messages_batch = d.add_chat_messages_batch
|
||||
|
||||
@@ -342,6 +342,19 @@ async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
return count > 0
|
||||
|
||||
|
||||
async def count_pending_reviews_for_graph_exec(
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
) -> int:
|
||||
return await PendingHumanReview.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_id(node_exec_id: str, get_node_execution) -> str:
|
||||
"""Resolve node_id from a node_exec_id.
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from backend.data.model import User
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
parse_business_understanding_input,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
@@ -63,18 +63,16 @@ class InvitedUserRecord(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
payload = parse_business_understanding_input(invited_user.tallyUnderstanding)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_understanding=(
|
||||
payload.model_dump(mode="json") if payload is not None else None
|
||||
),
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
@@ -185,19 +183,13 @@ async def _apply_tally_understanding(
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
input_data = parse_business_understanding_input(invited_user.tallyUnderstanding)
|
||||
if input_data is None:
|
||||
if invited_user.tallyUnderstanding is not None:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
)
|
||||
return
|
||||
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
@@ -223,6 +215,18 @@ async def list_invited_users(
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def list_invited_users_for_auth_users(
|
||||
auth_user_ids: list[str],
|
||||
) -> list[InvitedUserRecord]:
|
||||
if not auth_user_ids:
|
||||
return []
|
||||
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
where={"authUserId": {"in": auth_user_ids}}
|
||||
)
|
||||
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
|
||||
@@ -31,6 +31,18 @@ def _json_to_list(value: Any) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def parse_business_understanding_input(
|
||||
payload: Any,
|
||||
) -> "BusinessUnderstandingInput | None":
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return BusinessUnderstandingInput.model_validate(payload)
|
||||
except pydantic.ValidationError:
|
||||
return None
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
|
||||
@@ -62,6 +62,61 @@ async def get_user_by_id(user_id: str) -> User:
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def list_users(
|
||||
limit: int = 500,
|
||||
cursor: str | None = None,
|
||||
) -> list[User]:
|
||||
try:
|
||||
kwargs: dict = {
|
||||
"take": limit,
|
||||
"order": {"id": "asc"},
|
||||
}
|
||||
if cursor is not None:
|
||||
kwargs["cursor"] = {"id": cursor}
|
||||
kwargs["skip"] = 1
|
||||
users = await PrismaUser.prisma().find_many(**kwargs)
|
||||
return [User.from_db(user) for user in users]
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to list users: {e}") from e
|
||||
|
||||
|
||||
async def search_users(query: str, limit: int = 20) -> list[User]:
|
||||
normalized_query = query.strip()
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
try:
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{
|
||||
"email": {
|
||||
"contains": normalized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": {
|
||||
"contains": normalized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": {
|
||||
"contains": normalized_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
},
|
||||
]
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [User.from_db(user) for user in users]
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to search users for query {query!r}: {e}") from e
|
||||
|
||||
|
||||
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
|
||||
@@ -254,6 +254,23 @@ async def list_workspace_files(
|
||||
return [WorkspaceFile.from_db(f) for f in files]
|
||||
|
||||
|
||||
async def get_workspace_files_by_ids(
|
||||
workspace_id: str,
|
||||
file_ids: list[str],
|
||||
) -> list[WorkspaceFile]:
|
||||
if not file_ids:
|
||||
return []
|
||||
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": file_ids},
|
||||
"workspaceId": workspace_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
return [WorkspaceFile.from_db(file) for file in files]
|
||||
|
||||
|
||||
async def count_workspace_files(
|
||||
workspace_id: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
|
||||
@@ -24,6 +24,12 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.copilot.autopilot import (
|
||||
dispatch_nightly_copilot as dispatch_nightly_copilot_async,
|
||||
)
|
||||
from backend.copilot.autopilot import (
|
||||
send_nightly_copilot_emails as send_nightly_copilot_emails_async,
|
||||
)
|
||||
from backend.copilot.optimize_blocks import optimize_block_descriptions
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput, GraphInput
|
||||
@@ -259,6 +265,16 @@ def cleanup_oauth_tokens():
|
||||
run_async(_cleanup())
|
||||
|
||||
|
||||
def dispatch_nightly_copilot():
|
||||
"""Dispatch proactive nightly copilot sessions."""
|
||||
return run_async(dispatch_nightly_copilot_async())
|
||||
|
||||
|
||||
def send_nightly_copilot_emails():
|
||||
"""Send emails for completed non-manual copilot sessions."""
|
||||
return run_async(send_nightly_copilot_emails_async())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
"""Check execution accuracy and send alerts if drops are detected."""
|
||||
return report_execution_accuracy_alerts()
|
||||
@@ -404,7 +420,7 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
) -> "GraphExecutionJobInfo":
|
||||
# Extract timezone from the trigger if it's a CronTrigger
|
||||
timezone_str = "UTC"
|
||||
if hasattr(job_obj.trigger, "timezone"):
|
||||
if isinstance(job_obj.trigger, CronTrigger):
|
||||
timezone_str = str(job_obj.trigger.timezone)
|
||||
|
||||
return GraphExecutionJobInfo(
|
||||
@@ -619,6 +635,24 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_job(
|
||||
dispatch_nightly_copilot,
|
||||
id="dispatch_nightly_copilot",
|
||||
trigger=CronTrigger(minute="0,30", timezone=ZoneInfo("UTC")),
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_job(
|
||||
send_nightly_copilot_emails,
|
||||
id="send_nightly_copilot_emails",
|
||||
trigger=CronTrigger(minute="15,45", timezone=ZoneInfo("UTC")),
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
@@ -793,6 +827,14 @@ class Scheduler(AppService):
|
||||
"""Manually trigger embedding backfill for approved store agents."""
|
||||
return ensure_embeddings_coverage()
|
||||
|
||||
@expose
|
||||
def execute_dispatch_nightly_copilot(self):
|
||||
return dispatch_nightly_copilot()
|
||||
|
||||
@expose
|
||||
def execute_send_nightly_copilot_emails(self):
|
||||
return send_nightly_copilot_emails()
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Any
|
||||
|
||||
from postmarker.core import PostmarkClient
|
||||
from postmarker.models.emails import EmailManager
|
||||
@@ -7,12 +8,14 @@ from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
NotificationDataType_co,
|
||||
NotificationEventModel,
|
||||
NotificationTypeOverride,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.url import get_frontend_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -46,6 +49,102 @@ class EmailSender:
|
||||
|
||||
MAX_EMAIL_CHARS = 5_000_000 # ~5MB buffer
|
||||
|
||||
def _get_unsubscribe_link(self, user_unsubscribe_link: str | None) -> str:
|
||||
return user_unsubscribe_link or f"{get_frontend_base_url()}/profile/settings"
|
||||
|
||||
def _format_template_email(
|
||||
self,
|
||||
*,
|
||||
subject_template: str,
|
||||
content_template: str,
|
||||
data: Any,
|
||||
unsubscribe_link: str,
|
||||
) -> tuple[str, str]:
|
||||
return self.formatter.format_email(
|
||||
base_template=self._read_template("templates/base.html.jinja2"),
|
||||
subject_template=subject_template,
|
||||
content_template=content_template,
|
||||
data=data,
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def _build_large_output_summary(
|
||||
self,
|
||||
data: (
|
||||
NotificationEventModel[NotificationDataType_co]
|
||||
| list[NotificationEventModel[NotificationDataType_co]]
|
||||
),
|
||||
*,
|
||||
email_size: int,
|
||||
base_url: str,
|
||||
) -> str:
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
return (
|
||||
"⚠️ A notification generated a very large output "
|
||||
f"({email_size / 1_000_000:.2f} MB)."
|
||||
)
|
||||
event = data[0]
|
||||
else:
|
||||
event = data
|
||||
execution_url = (
|
||||
f"{base_url}/executions/{event.id}" if event.id is not None else None
|
||||
)
|
||||
|
||||
if isinstance(event.data, AgentRunData):
|
||||
lines = [
|
||||
f"⚠️ Your agent '{event.data.agent_name}' generated a very large output ({email_size / 1_000_000:.2f} MB).",
|
||||
"",
|
||||
f"Execution time: {event.data.execution_time}",
|
||||
f"Credits used: {event.data.credits_used}",
|
||||
]
|
||||
if execution_url is not None:
|
||||
lines.append(f"View full results: {execution_url}")
|
||||
return "\n".join(lines)
|
||||
|
||||
lines = [
|
||||
f"⚠️ A notification generated a very large output ({email_size / 1_000_000:.2f} MB).",
|
||||
]
|
||||
if execution_url is not None:
|
||||
lines.extend(["", f"View full results: {execution_url}"])
|
||||
return "\n".join(lines)
|
||||
|
||||
def send_template(
|
||||
self,
|
||||
*,
|
||||
user_email: str,
|
||||
subject: str,
|
||||
template_name: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
user_unsubscribe_link: str | None = None,
|
||||
) -> None:
|
||||
"""Send an email using a named Jinja2 template file.
|
||||
|
||||
Unlike ``send_templated`` (which resolves templates via
|
||||
``NotificationType``), this method accepts a template filename
|
||||
directly. Both delegate to the shared ``_format_template_email``
|
||||
+ ``_send_email`` pipeline.
|
||||
"""
|
||||
if not self.postmark:
|
||||
logger.warning("Postmark client not initialized, email not sent")
|
||||
return
|
||||
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
|
||||
|
||||
_, full_message = self._format_template_email(
|
||||
subject_template="{{ subject }}",
|
||||
content_template=self._read_template(f"templates/{template_name}"),
|
||||
data={"subject": subject, **(data or {})},
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
self._send_email(
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=full_message,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def send_templated(
|
||||
self,
|
||||
notification: NotificationType,
|
||||
@@ -62,21 +161,18 @@ class EmailSender:
|
||||
return
|
||||
|
||||
template = self._get_template(notification)
|
||||
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
base_url = get_frontend_base_url()
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsub_link)
|
||||
|
||||
# Normalize data
|
||||
template_data = {"notifications": data} if isinstance(data, list) else data
|
||||
|
||||
try:
|
||||
subject, full_message = self.formatter.format_email(
|
||||
base_template=template.base_template,
|
||||
subject, full_message = self._format_template_email(
|
||||
subject_template=template.subject_template,
|
||||
content_template=template.body_template,
|
||||
data=template_data,
|
||||
unsubscribe_link=f"{base_url}/profile/settings",
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting full message: {e}")
|
||||
@@ -90,20 +186,17 @@ class EmailSender:
|
||||
"Sending summary email instead."
|
||||
)
|
||||
|
||||
# Create lightweight summary
|
||||
summary_message = (
|
||||
f"⚠️ Your agent '{getattr(data, 'agent_name', 'Unknown')}' "
|
||||
f"generated a very large output ({email_size / 1_000_000:.2f} MB).\n\n"
|
||||
f"Execution time: {getattr(data, 'execution_time', 'N/A')}\n"
|
||||
f"Credits used: {getattr(data, 'credits_used', 'N/A')}\n"
|
||||
f"View full results: {base_url}/executions/{getattr(data, 'id', 'N/A')}"
|
||||
summary_message = self._build_large_output_summary(
|
||||
data,
|
||||
email_size=email_size,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
self._send_email(
|
||||
user_email=user_email,
|
||||
subject=f"{subject} (Output Too Large)",
|
||||
body=summary_message,
|
||||
user_unsubscribe_link=user_unsub_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
return # Skip sending full email
|
||||
|
||||
@@ -112,7 +205,7 @@ class EmailSender:
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=full_message,
|
||||
user_unsubscribe_link=user_unsub_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def _get_template(self, notification: NotificationType):
|
||||
@@ -123,17 +216,18 @@ class EmailSender:
|
||||
logger.debug(
|
||||
f"Template full path: {pathlib.Path(__file__).parent / template_path}"
|
||||
)
|
||||
base_template_path = "templates/base.html.jinja2"
|
||||
with open(pathlib.Path(__file__).parent / base_template_path, "r") as file:
|
||||
base_template = file.read()
|
||||
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
|
||||
template = file.read()
|
||||
base_template = self._read_template("templates/base.html.jinja2")
|
||||
template = self._read_template(template_path)
|
||||
return Template(
|
||||
subject_template=notification_type_override.subject,
|
||||
body_template=template,
|
||||
base_template=base_template,
|
||||
)
|
||||
|
||||
def _read_template(self, template_path: str) -> str:
|
||||
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
|
||||
return file.read()
|
||||
|
||||
def _send_email(
|
||||
self,
|
||||
user_email: str,
|
||||
@@ -144,18 +238,33 @@ class EmailSender:
|
||||
if not self.postmark:
|
||||
logger.warning("Email tried to send without postmark configured")
|
||||
return
|
||||
sender_email = settings.config.postmark_sender_email
|
||||
if not sender_email:
|
||||
logger.warning("postmark_sender_email not configured, email not sent")
|
||||
return
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
|
||||
logger.debug(f"Sending email to {user_email} with subject {subject}")
|
||||
self.postmark.emails.send(
|
||||
From=settings.config.postmark_sender_email,
|
||||
From=sender_email,
|
||||
To=user_email,
|
||||
Subject=subject,
|
||||
HtmlBody=body,
|
||||
Headers=(
|
||||
{
|
||||
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
|
||||
"List-Unsubscribe": f"<{user_unsubscribe_link}>",
|
||||
}
|
||||
if user_unsubscribe_link
|
||||
else None
|
||||
),
|
||||
Headers={
|
||||
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
|
||||
"List-Unsubscribe": f"<{unsubscribe_link}>",
|
||||
},
|
||||
)
|
||||
|
||||
def send_html(
|
||||
self,
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
user_unsubscribe_link: str | None = None,
|
||||
) -> None:
|
||||
self._send_email(
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
user_unsubscribe_link=user_unsubscribe_link,
|
||||
)
|
||||
|
||||
241
autogpt_platform/backend/backend/notifications/email_test.py
Normal file
241
autogpt_platform/backend/backend/notifications/email_test.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.api.test_helpers import override_config
|
||||
from backend.copilot.autopilot_email import _markdown_to_email_html
|
||||
from backend.notifications.email import EmailSender, settings
|
||||
from backend.util.settings import AppEnvironment
|
||||
|
||||
|
||||
def test_markdown_to_email_html_renders_bold_and_italic() -> None:
|
||||
html = _markdown_to_email_html("**bold** and *italic*")
|
||||
assert "<strong>bold</strong>" in html
|
||||
assert "<em>italic</em>" in html
|
||||
assert 'style="' in html
|
||||
|
||||
|
||||
def test_markdown_to_email_html_renders_links() -> None:
|
||||
html = _markdown_to_email_html("[click here](https://example.com)")
|
||||
assert 'href="https://example.com"' in html
|
||||
assert "click here" in html
|
||||
assert "color: #7733F5" in html
|
||||
|
||||
|
||||
def test_markdown_to_email_html_renders_bullet_list() -> None:
|
||||
html = _markdown_to_email_html("- item one\n- item two")
|
||||
assert "<ul" in html
|
||||
assert "<li" in html
|
||||
assert "item one" in html
|
||||
assert "item two" in html
|
||||
|
||||
|
||||
def test_markdown_to_email_html_handles_empty_input() -> None:
|
||||
assert _markdown_to_email_html(None) == ""
|
||||
assert _markdown_to_email_html("") == ""
|
||||
assert _markdown_to_email_html(" ") == ""
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_email(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I found something useful for you.\n\n"
|
||||
"Open Copilot and I will walk you through it."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Open Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "I found something useful for you." in body
|
||||
assert "Open Copilot" in body
|
||||
assert "Approval needed" not in body
|
||||
assert send_email.call_args.kwargs["user_unsubscribe_link"].endswith(
|
||||
"/profile/settings"
|
||||
)
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_approval_block(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I prepared a change worth reviewing."
|
||||
),
|
||||
"approval_summary_html": _markdown_to_email_html(
|
||||
"I drafted a follow-up because it matches your recent activity."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
|
||||
"cta_label": "Review in Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Approval needed" in body
|
||||
assert "If you want it to happen, please hit approve." in body
|
||||
assert "Review in Copilot" in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_callback_email(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_callback.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I prepared a follow-up based on your recent work."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Open Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Autopilot picked up where you left off" in body
|
||||
assert "I prepared a follow-up based on your recent work." in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_callback_approval_block(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_callback.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I prepared a follow-up based on your recent work."
|
||||
),
|
||||
"approval_summary_html": _markdown_to_email_html(
|
||||
"I want your approval before I apply the next step."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
|
||||
"cta_label": "Review in Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Approval needed" in body
|
||||
assert "I want your approval before I apply the next step." in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_invite_cta_email(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_invite_cta.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I put together an example of how Autopilot could help you."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Try Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Your Autopilot beta access is waiting" in body
|
||||
assert "I put together an example of how Autopilot could help you." in body
|
||||
assert "Try Copilot" in body
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_invite_cta_approval_block(
|
||||
mocker,
|
||||
) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot_invite_cta.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I put together an example of how Autopilot could help you."
|
||||
),
|
||||
"approval_summary_html": _markdown_to_email_html(
|
||||
"If this looks useful, approve the next step to try it."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
|
||||
"cta_label": "Review in Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
body = send_email.call_args.kwargs["body"]
|
||||
|
||||
assert "Approval needed" in body
|
||||
assert "If this looks useful, approve the next step to try it." in body
|
||||
|
||||
|
||||
def test_send_template_still_sends_in_production(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
sender.postmark = cast(Any, object())
|
||||
send_email = mocker.patch.object(sender, "_send_email")
|
||||
|
||||
with override_config(settings, "app_env", AppEnvironment.PRODUCTION):
|
||||
sender.send_template(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
template_name="nightly_copilot.html.jinja2",
|
||||
data={
|
||||
"email_body_html": _markdown_to_email_html(
|
||||
"I found something useful for you."
|
||||
),
|
||||
"cta_url": "https://example.com/copilot?callbackToken=token-1",
|
||||
"cta_label": "Open Copilot",
|
||||
},
|
||||
)
|
||||
|
||||
send_email.assert_called_once()
|
||||
|
||||
|
||||
def test_send_html_uses_default_unsubscribe_link(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
send = mocker.Mock()
|
||||
sender.postmark = cast(Any, SimpleNamespace(emails=SimpleNamespace(send=send)))
|
||||
|
||||
mocker.patch(
|
||||
"backend.notifications.email.get_frontend_base_url",
|
||||
return_value="https://example.com",
|
||||
)
|
||||
with override_config(settings, "postmark_sender_email", "test@example.com"):
|
||||
sender.send_html(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
body="<p>Hello</p>",
|
||||
)
|
||||
|
||||
headers = send.call_args.kwargs["Headers"]
|
||||
|
||||
assert headers["List-Unsubscribe-Post"] == "List-Unsubscribe=One-Click"
|
||||
assert headers["List-Unsubscribe"] == "<https://example.com/profile/settings>"
|
||||
@@ -29,7 +29,6 @@
|
||||
</noscript>
|
||||
<![endif]-->
|
||||
<style type="text/css">
|
||||
/* RESET STYLES */
|
||||
html,
|
||||
body {
|
||||
margin: 0 !important;
|
||||
@@ -85,7 +84,6 @@
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
/* iOS BLUE LINKS */
|
||||
a[x-apple-data-detectors] {
|
||||
color: inherit !important;
|
||||
text-decoration: none !important;
|
||||
@@ -95,13 +93,11 @@
|
||||
line-height: inherit !important;
|
||||
}
|
||||
|
||||
/* ANDROID CENTER FIX */
|
||||
div[style*="margin: 16px 0;"] {
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
/* MEDIA QUERIES */
|
||||
@media all and (max-width:639px) {
|
||||
@media all and (max-width: 639px) {
|
||||
.wrapper {
|
||||
width: 100% !important;
|
||||
}
|
||||
@@ -113,8 +109,8 @@
|
||||
}
|
||||
|
||||
.row {
|
||||
padding-left: 20px !important;
|
||||
padding-right: 20px !important;
|
||||
padding-left: 24px !important;
|
||||
padding-right: 24px !important;
|
||||
}
|
||||
|
||||
.col-mobile {
|
||||
@@ -136,11 +132,6 @@
|
||||
float: none !important;
|
||||
}
|
||||
|
||||
.mobile-left {
|
||||
text-align: center !important;
|
||||
float: left !important;
|
||||
}
|
||||
|
||||
.mobile-hide {
|
||||
display: none !important;
|
||||
}
|
||||
@@ -155,9 +146,9 @@
|
||||
max-width: 100% !important;
|
||||
}
|
||||
|
||||
.ml-btn-container {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
.card-inner {
|
||||
padding-left: 24px !important;
|
||||
padding-right: 24px !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -174,170 +165,139 @@
|
||||
<title>{{data.title}}</title>
|
||||
</head>
|
||||
|
||||
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
|
||||
<body style="margin: 0 !important; padding: 0 !important; background-color: #070629;">
|
||||
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
|
||||
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
|
||||
<!-- Main Content -->
|
||||
style="background-color: #070629; line-height: 100%; font-size: medium; font-size: max(16px, 1rem);">
|
||||
|
||||
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
|
||||
<tr>
|
||||
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
|
||||
<!-- Email Content -->
|
||||
<td align="center" valign="top" style="padding: 48px 16px 40px;">
|
||||
|
||||
<!-- ============ CARD ============ -->
|
||||
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
|
||||
<!-- Gradient Accent Strip -->
|
||||
<tr>
|
||||
<td align="center">
|
||||
<!-- Logo Section -->
|
||||
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td class="row" align="center" style="padding: 0 50px;">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="" width="120" class="logo"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<td bgcolor="#7733F5"
|
||||
style="background: linear-gradient(90deg, #7733F5 0%, #60A5FA 35%, #EC4899 65%, #7733F5 100%); height: 6px; border-radius: 16px 16px 0 0; font-size: 0; line-height: 0;">
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Main Content Section -->
|
||||
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="row" style="padding: 0 50px;">
|
||||
{{data.message|safe}}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<!-- Logo -->
|
||||
<tr>
|
||||
<td bgcolor="#FFFFFF" align="center" class="card-inner" style="padding: 32px 48px 24px;">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="AutoGPT" width="120"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Signature Section -->
|
||||
<table class="container ml-8 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
|
||||
<!-- Divider -->
|
||||
<tr>
|
||||
<td bgcolor="#FFFFFF" class="card-inner" style="padding: 0 48px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td class="row mobile-center" align="left" style="padding: 0 50px;">
|
||||
<table class="ml-8 wrapper" border="0" cellspacing="0" cellpadding="0"
|
||||
style="color: #070629; text-align: left;">
|
||||
<tr>
|
||||
<td class="col center mobile-center" align>
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 0;">
|
||||
Thank you for being a part of the AutoGPT community! Join the conversation on our Discord <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">here</a> and share your thoughts with us anytime.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Footer Section -->
|
||||
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
|
||||
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
|
||||
<tr>
|
||||
<td class="row" style="padding: 0 50px;">
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td>
|
||||
<!-- Footer Content -->
|
||||
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
<td class="col" align="left" valign="middle" width="120">
|
||||
<img
|
||||
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
|
||||
border="0" alt="" width="120" class="logo"
|
||||
style="max-width: 120px; display: inline-block;">
|
||||
</td>
|
||||
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
|
||||
<td class="col mobile-left" align="right" valign="middle" width="250">
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
|
||||
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
|
||||
width="18" alt="x">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
|
||||
<a href="https://discord.gg/autogpt" target="blank"
|
||||
style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
|
||||
width="18" alt="discord">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
|
||||
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
|
||||
width="18" alt="website">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="text-align: left!important;">
|
||||
<h5
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 15px; line-height: 125%; font-weight: bold; font-style: normal; text-decoration: none; margin-bottom: 6px;">
|
||||
AutoGPT
|
||||
</h5>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="text-align: left!important;">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
3rd Floor 1 Ashley Road, Cheshire, United Kingdom, WA14 2DT, Altrincham<br>United Kingdom
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="8" style="line-height: 8px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left" style="text-align: left!important;">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
You received this email because you signed up on our website.</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td height="1" style="line-height: 12px;"></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
|
||||
<a href="{{data.unsubscribe_link}}"
|
||||
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">Unsubscribe</a>
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td style="border-top: 1px solid #E8E5F0; font-size: 0; line-height: 0; height: 1px;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- Content -->
|
||||
<tr>
|
||||
<td class="card-inner" bgcolor="#FFFFFF"
|
||||
style="padding: 36px 48px 44px; color: #1F1F20; font-family: 'Poppins', sans-serif; border-radius: 0 0 16px 16px;">
|
||||
{{data.message|safe}}
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
<!-- ============ END CARD ============ -->
|
||||
|
||||
<!-- Spacer -->
|
||||
<table width="640" class="container" align="center" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
<tr>
|
||||
<td style="height: 40px; font-size: 0; line-height: 0;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- ============ FOOTER ============ -->
|
||||
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
|
||||
style="max-width: 640px;">
|
||||
<tr>
|
||||
<td align="center" style="padding: 0 48px;">
|
||||
|
||||
<!-- Logo Text -->
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; font-size: 22px; font-weight: 700; color: #FFFFFF; margin: 0 0 16px; letter-spacing: -0.5px;">
|
||||
AutoGPT</p>
|
||||
|
||||
<!-- Social Icons -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" align="center">
|
||||
<tr>
|
||||
<td align="center" valign="middle" style="padding: 0 8px;">
|
||||
<a href="https://x.com/auto_gpt" target="_blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/x.png"
|
||||
width="22" alt="X" style="display: block; opacity: 0.6;">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" style="padding: 0 8px;">
|
||||
<a href="https://discord.gg/autogpt" target="_blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/discord.png"
|
||||
width="22" alt="Discord" style="display: block; opacity: 0.6;">
|
||||
</a>
|
||||
</td>
|
||||
<td align="center" valign="middle" style="padding: 0 8px;">
|
||||
<a href="https://agpt.co/" target="_blank" style="text-decoration: none;">
|
||||
<img
|
||||
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/website.png"
|
||||
width="22" alt="Website" style="display: block; opacity: 0.6;">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Spacer -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="height: 20px; font-size: 0; line-height: 0;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Divider -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td
|
||||
style="border-top: 1px solid rgba(255,255,255,0.08); font-size: 0; line-height: 0; height: 1px;">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Footer Text -->
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: rgba(255,255,255,0.3); font-size: 12px; line-height: 165%; margin: 16px 0 0; text-align: center;">
|
||||
AutoGPT · 3rd Floor 1 Ashley Road, Altrincham, WA14 2DT, United Kingdom
|
||||
</p>
|
||||
<p
|
||||
style="font-family: 'Poppins', sans-serif; color: rgba(255,255,255,0.3); font-size: 12px; line-height: 165%; margin: 4px 0 0; text-align: center;">
|
||||
You received this email because you signed up on our website.
|
||||
<a href="{{data.unsubscribe_link}}"
|
||||
style="color: rgba(255,255,255,0.45); text-decoration: underline;">Unsubscribe</a>
|
||||
</p>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
|
||||
{{ email_body_html|safe }}
|
||||
|
||||
{% if approval_summary_html %}
|
||||
<!-- Approval Callout -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
|
||||
<tr>
|
||||
<td bgcolor="#FFF3E6"
|
||||
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="padding: 0 0 12px 0;">
|
||||
<span
|
||||
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
|
||||
Approval needed
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{{ approval_summary_html|safe }}
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
|
||||
I thought this was a good idea. If you want it to happen, please hit approve.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<!-- CTA Button -->
|
||||
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#7733F5"
|
||||
style="background-color: #7733F5; border-radius: 12px;">
|
||||
<a href="{{ cta_url }}"
|
||||
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
|
||||
{{ cta_label }}
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
@@ -0,0 +1,58 @@
|
||||
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
|
||||
<!-- Header -->
|
||||
<h2
|
||||
style="font-size: 24px; line-height: 130%; font-weight: 700; margin-top: 0; margin-bottom: 8px; color: #1F1F20; letter-spacing: -0.5px;">
|
||||
Autopilot picked up where you left off
|
||||
</h2>
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 28px; color: #505057;">
|
||||
We used your recent Copilot activity to prepare a concrete follow-up for you.
|
||||
</p>
|
||||
|
||||
<!-- Divider -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-bottom: 28px;">
|
||||
<tr>
|
||||
<td style="border-top: 1px solid #E8E5F0; font-size: 0; line-height: 0; height: 1px;"> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Body -->
|
||||
{{ email_body_html|safe }}
|
||||
|
||||
{% if approval_summary_html %}
|
||||
<!-- Approval Callout -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
|
||||
<tr>
|
||||
<td bgcolor="#FFF3E6"
|
||||
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="padding: 0 0 12px 0;">
|
||||
<span
|
||||
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
|
||||
Approval needed
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{{ approval_summary_html|safe }}
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
|
||||
I thought this was a good idea. If you want it to happen, please hit approve.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<!-- CTA Button -->
|
||||
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#7733F5"
|
||||
style="background-color: #7733F5; border-radius: 12px;">
|
||||
<a href="{{ cta_url }}"
|
||||
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
|
||||
{{ cta_label }}
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
@@ -0,0 +1,64 @@
|
||||
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
|
||||
<!-- Header -->
|
||||
<h2
|
||||
style="font-size: 24px; line-height: 130%; font-weight: 700; margin-top: 0; margin-bottom: 8px; color: #1F1F20; letter-spacing: -0.5px;">
|
||||
Your Autopilot beta access is waiting
|
||||
</h2>
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 28px; color: #505057;">
|
||||
You applied to try Autopilot. Here is a tailored example of how it can help once you jump back in.
|
||||
</p>
|
||||
|
||||
<!-- Highlight Card -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-bottom: 28px;">
|
||||
<tr>
|
||||
<td bgcolor="#5424AE"
|
||||
style="background-color: #5424AE; border-radius: 12px; padding: 20px 24px;">
|
||||
<p
|
||||
style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 0; color: #FFFFFF; font-weight: 500;">
|
||||
Autopilot works in the background to handle tasks, surface insights, and take action on your behalf.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Body -->
|
||||
{{ email_body_html|safe }}
|
||||
|
||||
{% if approval_summary_html %}
|
||||
<!-- Approval Callout -->
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
|
||||
<tr>
|
||||
<td bgcolor="#FFF3E6"
|
||||
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" border="0">
|
||||
<tr>
|
||||
<td style="padding: 0 0 12px 0;">
|
||||
<span
|
||||
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
|
||||
Approval needed
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{{ approval_summary_html|safe }}
|
||||
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
|
||||
I thought this was a good idea. If you want it to happen, please hit approve.
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<!-- CTA Button -->
|
||||
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
|
||||
<tr>
|
||||
<td align="center" bgcolor="#7733F5"
|
||||
style="background-color: #7733F5; border-radius: 12px;">
|
||||
<a href="{{ cta_url }}"
|
||||
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
|
||||
{{ cta_label }}
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
@@ -39,6 +39,7 @@ class Flag(str, Enum):
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
NIGHTLY_COPILOT = "nightly-copilot"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import date
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
|
||||
|
||||
@@ -125,6 +126,22 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
nightly_copilot_callback_start_date: date = Field(
|
||||
default=date(2026, 2, 8),
|
||||
description="Users with sessions since this date are eligible for the one-off autopilot callback cohort.",
|
||||
)
|
||||
nightly_copilot_invite_cta_start_date: date = Field(
|
||||
default=date(2026, 3, 13),
|
||||
description="Invite CTA cohort does not run before this date.",
|
||||
)
|
||||
nightly_copilot_invite_cta_delay_hours: int = Field(
|
||||
default=48,
|
||||
description="Delay after invite creation before the invite CTA can run.",
|
||||
)
|
||||
nightly_copilot_callback_token_ttl_hours: int = Field(
|
||||
default=24 * 14,
|
||||
description="TTL for nightly copilot callback tokens.",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
default=False,
|
||||
description="If user credit system is enabled or not",
|
||||
|
||||
@@ -86,9 +86,13 @@ class TextFormatter:
|
||||
"i",
|
||||
"img",
|
||||
"li",
|
||||
"ol",
|
||||
"p",
|
||||
"span",
|
||||
"strong",
|
||||
"table",
|
||||
"td",
|
||||
"tr",
|
||||
"u",
|
||||
"ul",
|
||||
]
|
||||
@@ -98,6 +102,15 @@ class TextFormatter:
|
||||
"*": ["class", "style"],
|
||||
"a": ["href"],
|
||||
"img": ["src"],
|
||||
"table": [
|
||||
"align",
|
||||
"border",
|
||||
"cellpadding",
|
||||
"cellspacing",
|
||||
"role",
|
||||
"width",
|
||||
],
|
||||
"td": ["align", "bgcolor", "colspan", "height", "valign", "width"],
|
||||
}
|
||||
|
||||
def format_string(self, template_str: str, values=None, **kwargs) -> str:
|
||||
|
||||
9
autogpt_platform/backend/backend/util/url.py
Normal file
9
autogpt_platform/backend/backend/util/url.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_frontend_base_url() -> str:
|
||||
return (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
).rstrip("/")
|
||||
@@ -0,0 +1,67 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ChatSessionStartType" AS ENUM(
|
||||
'MANUAL',
|
||||
'AUTOPILOT_NIGHTLY',
|
||||
'AUTOPILOT_CALLBACK',
|
||||
'AUTOPILOT_INVITE_CTA'
|
||||
);
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "ChatSession"
|
||||
ADD COLUMN "startType" "ChatSessionStartType" NOT NULL DEFAULT 'MANUAL',
|
||||
ADD COLUMN "executionTag" TEXT,
|
||||
ADD COLUMN "sessionConfig" JSONB NOT NULL DEFAULT '{}',
|
||||
ADD COLUMN "completionReport" JSONB,
|
||||
ADD COLUMN "completionReportRepairCount" INTEGER NOT NULL DEFAULT 0,
|
||||
ADD COLUMN "completionReportRepairQueuedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "completedAt" TIMESTAMP(3),
|
||||
ADD COLUMN "notificationEmailSentAt" TIMESTAMP(3),
|
||||
ADD COLUMN "notificationEmailSkippedAt" TIMESTAMP(3);
|
||||
|
||||
COMMENT ON COLUMN "ChatSession"."sessionConfig" IS 'Validated by backend.copilot.session_types.ChatSessionConfig';
|
||||
COMMENT ON COLUMN "ChatSession"."completionReport" IS 'Validated by backend.copilot.session_types.StoredCompletionReport';
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ChatSessionCallbackToken"(
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"sourceSessionId" TEXT,
|
||||
"callbackSessionMessage" TEXT NOT NULL,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"consumedAt" TIMESTAMP(3),
|
||||
"consumedSessionId" TEXT,
|
||||
CONSTRAINT "ChatSessionCallbackToken_pkey" PRIMARY KEY("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ChatSession_userId_executionTag_key"
|
||||
ON "ChatSession"("userId",
|
||||
"executionTag");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSession_userId_startType_updatedAt_idx"
|
||||
ON "ChatSession"("userId",
|
||||
"startType",
|
||||
"updatedAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSessionCallbackToken_userId_expiresAt_idx"
|
||||
ON "ChatSessionCallbackToken"("userId",
|
||||
"expiresAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ChatSessionCallbackToken_consumedSessionId_idx"
|
||||
ON "ChatSessionCallbackToken"("consumedSessionId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatSessionCallbackToken" ADD CONSTRAINT "ChatSessionCallbackToken_userId_fkey" FOREIGN KEY("userId") REFERENCES "User"("id")
|
||||
ON DELETE CASCADE
|
||||
ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ChatSessionCallbackToken" ADD CONSTRAINT "ChatSessionCallbackToken_sourceSessionId_fkey" FOREIGN KEY("sourceSessionId") REFERENCES "ChatSession"("id")
|
||||
ON DELETE
|
||||
CASCADE
|
||||
ON UPDATE CASCADE;
|
||||
@@ -37,6 +37,7 @@ jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^3.14.1"
|
||||
markdown-it-py = "^3.0.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
|
||||
@@ -66,6 +66,7 @@ model User {
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
|
||||
ChatSessionCallbackTokens ChatSessionCallbackToken[]
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -87,6 +88,13 @@ enum TallyComputationStatus {
|
||||
FAILED
|
||||
}
|
||||
|
||||
enum ChatSessionStartType {
|
||||
MANUAL
|
||||
AUTOPILOT_NIGHTLY
|
||||
AUTOPILOT_CALLBACK
|
||||
AUTOPILOT_INVITE_CTA
|
||||
}
|
||||
|
||||
model InvitedUser {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -248,6 +256,15 @@ model ChatSession {
|
||||
// Session metadata
|
||||
title String?
|
||||
credentials Json @default("{}") // Map of provider -> credential metadata
|
||||
startType ChatSessionStartType @default(MANUAL)
|
||||
executionTag String?
|
||||
sessionConfig Json @default("{}") // ChatSessionConfig payload from backend.copilot.session_types.ChatSessionConfig
|
||||
completionReport Json? // StoredCompletionReport payload from backend.copilot.session_types.StoredCompletionReport
|
||||
completionReportRepairCount Int @default(0)
|
||||
completionReportRepairQueuedAt DateTime?
|
||||
completedAt DateTime?
|
||||
notificationEmailSentAt DateTime?
|
||||
notificationEmailSkippedAt DateTime?
|
||||
|
||||
// Rate limiting counters (stored as JSON maps)
|
||||
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
|
||||
@@ -258,8 +275,31 @@ model ChatSession {
|
||||
totalCompletionTokens Int @default(0)
|
||||
|
||||
Messages ChatMessage[]
|
||||
CallbackTokens ChatSessionCallbackToken[] @relation("ChatSessionCallbackSource")
|
||||
|
||||
@@index([userId, updatedAt])
|
||||
@@index([userId, startType, updatedAt])
|
||||
@@unique([userId, executionTag])
|
||||
}
|
||||
|
||||
model ChatSessionCallbackToken {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
sourceSessionId String?
|
||||
SourceSession ChatSession? @relation("ChatSessionCallbackSource", fields: [sourceSessionId], references: [id], onDelete: Cascade)
|
||||
|
||||
callbackSessionMessage String
|
||||
expiresAt DateTime
|
||||
consumedAt DateTime?
|
||||
consumedSessionId String?
|
||||
|
||||
@@index([userId, expiresAt])
|
||||
@@index([consumedSessionId])
|
||||
}
|
||||
|
||||
model ChatMessage {
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
"use client";
|
||||
|
||||
import { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Card } from "@/components/atoms/Card/Card";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { CopilotUsersTable } from "../CopilotUsersTable/CopilotUsersTable";
|
||||
import { useAdminCopilotPage } from "../../useAdminCopilotPage";
|
||||
|
||||
function getStartTypeLabel(startType: ChatSessionStartType) {
|
||||
if (startType === ChatSessionStartType.AUTOPILOT_INVITE_CTA) {
|
||||
return "CTA";
|
||||
}
|
||||
|
||||
if (startType === ChatSessionStartType.AUTOPILOT_NIGHTLY) {
|
||||
return "Nightly";
|
||||
}
|
||||
|
||||
if (startType === ChatSessionStartType.AUTOPILOT_CALLBACK) {
|
||||
return "Callback";
|
||||
}
|
||||
|
||||
return startType;
|
||||
}
|
||||
|
||||
const triggerOptions = [
|
||||
{
|
||||
label: "Trigger CTA",
|
||||
description:
|
||||
"Runs the beta invite CTA flow even if the user would not normally qualify.",
|
||||
startType: ChatSessionStartType.AUTOPILOT_INVITE_CTA,
|
||||
variant: "primary" as const,
|
||||
},
|
||||
{
|
||||
label: "Trigger Nightly",
|
||||
description:
|
||||
"Runs the nightly proactive Autopilot flow immediately for the selected user.",
|
||||
startType: ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
variant: "outline" as const,
|
||||
},
|
||||
{
|
||||
label: "Trigger Callback",
|
||||
description:
|
||||
"Runs the callback re-engagement flow without checking the normal callback cohort.",
|
||||
startType: ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
variant: "secondary" as const,
|
||||
},
|
||||
];
|
||||
|
||||
export function AdminCopilotPage() {
|
||||
const {
|
||||
search,
|
||||
selectedUser,
|
||||
pendingTriggerType,
|
||||
lastTriggeredSession,
|
||||
lastEmailSweepResult,
|
||||
searchedUsers,
|
||||
searchErrorMessage,
|
||||
isSearchingUsers,
|
||||
isRefreshingUsers,
|
||||
isTriggeringSession,
|
||||
isSendingPendingEmails,
|
||||
hasSearch,
|
||||
setSearch,
|
||||
handleSelectUser,
|
||||
handleSendPendingEmails,
|
||||
handleTriggerSession,
|
||||
} = useAdminCopilotPage();
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h1 className="text-3xl font-bold text-zinc-900">Copilot</h1>
|
||||
<p className="max-w-3xl text-sm text-zinc-600">
|
||||
Manually create CTA, Nightly, or Callback Copilot sessions for a
|
||||
specific user. These controls bypass the normal eligibility checks so
|
||||
you can test each flow directly.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 xl:grid-cols-[minmax(0,1.35fr),24rem]">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<Input
|
||||
id="copilot-user-search"
|
||||
label="Search users"
|
||||
hint="Results update as you type"
|
||||
placeholder="Search by email, name, or user ID"
|
||||
value={search}
|
||||
onChange={(event) => setSearch(event.target.value)}
|
||||
/>
|
||||
{searchErrorMessage ? (
|
||||
<p className="-mt-2 text-sm text-red-500">{searchErrorMessage}</p>
|
||||
) : null}
|
||||
<CopilotUsersTable
|
||||
users={searchedUsers}
|
||||
isLoading={isSearchingUsers}
|
||||
isRefreshing={isRefreshingUsers}
|
||||
hasSearch={hasSearch}
|
||||
selectedUserId={selectedUser?.id ?? null}
|
||||
onSelectUser={handleSelectUser}
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<div className="flex flex-col gap-6">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Selected user
|
||||
</h2>
|
||||
{selectedUser ? <Badge variant="info">Ready</Badge> : null}
|
||||
</div>
|
||||
|
||||
{selectedUser ? (
|
||||
<div className="flex flex-col gap-3 text-sm text-zinc-600">
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Email
|
||||
</span>
|
||||
<p className="mt-1 font-medium text-zinc-900">
|
||||
{selectedUser.email}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Name
|
||||
</span>
|
||||
<p className="mt-1">
|
||||
{selectedUser.name || "No display name"}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Timezone
|
||||
</span>
|
||||
<p className="mt-1">{selectedUser.timezone}</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
User ID
|
||||
</span>
|
||||
<p className="mt-1 break-all font-mono text-xs text-zinc-500">
|
||||
{selectedUser.id}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-sm text-zinc-500">
|
||||
Select a user from the results table to enable manual Copilot
|
||||
triggers.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Trigger flows
|
||||
</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Each action creates a new session immediately for the selected
|
||||
user.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-3">
|
||||
{triggerOptions.map((option) => (
|
||||
<div
|
||||
key={option.startType}
|
||||
className="rounded-2xl border border-zinc-200 p-4"
|
||||
>
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="font-medium text-zinc-900">
|
||||
{option.label}
|
||||
</span>
|
||||
<p className="text-sm text-zinc-600">
|
||||
{option.description}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant={option.variant}
|
||||
disabled={!selectedUser || isTriggeringSession}
|
||||
loading={pendingTriggerType === option.startType}
|
||||
onClick={() => handleTriggerSession(option.startType)}
|
||||
>
|
||||
{option.label}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Email follow-up
|
||||
</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Run the pending Copilot completion-email sweep immediately for
|
||||
the selected user.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
disabled={!selectedUser || isSendingPendingEmails}
|
||||
loading={isSendingPendingEmails}
|
||||
onClick={handleSendPendingEmails}
|
||||
>
|
||||
Send pending emails
|
||||
</Button>
|
||||
{selectedUser && lastEmailSweepResult ? (
|
||||
<div className="rounded-2xl border border-zinc-200 p-4 text-sm text-zinc-600">
|
||||
<p className="font-medium text-zinc-900">
|
||||
Last sweep for {selectedUser.email}
|
||||
</p>
|
||||
<div className="mt-3 grid gap-3 sm:grid-cols-2">
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Candidates
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.candidate_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Processed
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.processed_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Sent
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.sent_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Skipped
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.skipped_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Repairs queued
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.repair_queued_count}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Running / failed
|
||||
</span>
|
||||
<p className="mt-1 text-zinc-900">
|
||||
{lastEmailSweepResult.running_count} /{" "}
|
||||
{lastEmailSweepResult.failed_count}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{selectedUser && lastTriggeredSession ? (
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">
|
||||
Latest session
|
||||
</h2>
|
||||
<Badge variant="success">
|
||||
{getStartTypeLabel(lastTriggeredSession.start_type)}
|
||||
</Badge>
|
||||
</div>
|
||||
<p className="text-sm text-zinc-600">
|
||||
A new Copilot session was created for {selectedUser.email}.
|
||||
</p>
|
||||
<div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
Session ID
|
||||
</span>
|
||||
<p className="mt-1 break-all font-mono text-xs text-zinc-500">
|
||||
{lastTriggeredSession.session_id}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
as="NextLink"
|
||||
href={`/copilot?sessionId=${lastTriggeredSession.session_id}&showAutopilot=1`}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
>
|
||||
Open session
|
||||
</Button>
|
||||
</div>
|
||||
</Card>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
"use client";
|
||||
|
||||
import type { AdminCopilotUserSummary } from "@/app/api/__generated__/models/adminCopilotUserSummary";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
interface Props {
|
||||
users: AdminCopilotUserSummary[];
|
||||
isLoading: boolean;
|
||||
isRefreshing: boolean;
|
||||
hasSearch: boolean;
|
||||
selectedUserId: string | null;
|
||||
onSelectUser: (user: AdminCopilotUserSummary) => void;
|
||||
}
|
||||
|
||||
function formatDate(value: Date) {
|
||||
return value.toLocaleString();
|
||||
}
|
||||
|
||||
export function CopilotUsersTable({
|
||||
users,
|
||||
isLoading,
|
||||
isRefreshing,
|
||||
hasSearch,
|
||||
selectedUserId,
|
||||
onSelectUser,
|
||||
}: Props) {
|
||||
let emptyMessage = "Search by email, name, or user ID to find a user.";
|
||||
if (hasSearch && isLoading) {
|
||||
emptyMessage = "Searching users...";
|
||||
} else if (hasSearch) {
|
||||
emptyMessage = "No matching users found.";
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">User results</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Select an existing user, then run an Autopilot flow manually.
|
||||
</p>
|
||||
</div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
{isRefreshing
|
||||
? "Refreshing"
|
||||
: `${users.length} result${users.length === 1 ? "" : "s"}`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="overflow-hidden rounded-2xl border border-zinc-200">
|
||||
<Table>
|
||||
<TableHeader className="bg-zinc-50">
|
||||
<TableRow>
|
||||
<TableHead>User</TableHead>
|
||||
<TableHead>Timezone</TableHead>
|
||||
<TableHead>Updated</TableHead>
|
||||
<TableHead className="text-right">Action</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{users.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={4}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
{emptyMessage}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
users.map((user) => (
|
||||
<TableRow key={user.id} className="align-top">
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="font-medium text-zinc-900">
|
||||
{user.email}
|
||||
</span>
|
||||
<span className="text-sm text-zinc-600">
|
||||
{user.name || "No display name"}
|
||||
</span>
|
||||
<span className="font-mono text-xs text-zinc-400">
|
||||
{user.id}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-600">
|
||||
{user.timezone}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-600">
|
||||
{formatDate(user.updated_at)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
variant={
|
||||
user.id === selectedUserId ? "secondary" : "outline"
|
||||
}
|
||||
size="small"
|
||||
onClick={() => onSelectUser(user)}
|
||||
>
|
||||
{user.id === selectedUserId ? "Selected" : "Select"}
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { AdminCopilotPage } from "./components/AdminCopilotPage/AdminCopilotPage";
|
||||
|
||||
function AdminCopilot() {
|
||||
return <AdminCopilotPage />;
|
||||
}
|
||||
|
||||
export default async function AdminCopilotRoute() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminCopilot = await withAdminAccess(AdminCopilot);
|
||||
return <ProtectedAdminCopilot />;
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
"use client";
|
||||
|
||||
import type { AdminCopilotUserSummary } from "@/app/api/__generated__/models/adminCopilotUserSummary";
|
||||
import { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import type { SendCopilotEmailsResponse } from "@/app/api/__generated__/models/sendCopilotEmailsResponse";
|
||||
import type { TriggerCopilotSessionResponse } from "@/app/api/__generated__/models/triggerCopilotSessionResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { customMutator } from "@/app/api/mutators/custom-mutator";
|
||||
import {
|
||||
useGetV2SearchCopilotUsers,
|
||||
usePostV2TriggerCopilotSession,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { ApiError } from "@/lib/autogpt-server-api/helpers";
|
||||
import { useMutation } from "@tanstack/react-query";
|
||||
import { useDeferredValue, useState } from "react";
|
||||
|
||||
function getErrorMessage(error: unknown) {
|
||||
if (error instanceof ApiError) {
|
||||
if (
|
||||
typeof error.response === "object" &&
|
||||
error.response !== null &&
|
||||
"detail" in error.response &&
|
||||
typeof error.response.detail === "string"
|
||||
) {
|
||||
return error.response.detail;
|
||||
}
|
||||
|
||||
return error.message;
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return "Something went wrong";
|
||||
}
|
||||
|
||||
export function useAdminCopilotPage() {
|
||||
const { toast } = useToast();
|
||||
const [search, setSearch] = useState("");
|
||||
const [selectedUser, setSelectedUser] =
|
||||
useState<AdminCopilotUserSummary | null>(null);
|
||||
const [pendingTriggerType, setPendingTriggerType] =
|
||||
useState<ChatSessionStartType | null>(null);
|
||||
const [lastTriggeredSession, setLastTriggeredSession] =
|
||||
useState<TriggerCopilotSessionResponse | null>(null);
|
||||
const [lastEmailSweepResult, setLastEmailSweepResult] =
|
||||
useState<SendCopilotEmailsResponse | null>(null);
|
||||
|
||||
const deferredSearch = useDeferredValue(search);
|
||||
const normalizedSearch = deferredSearch.trim();
|
||||
|
||||
const searchUsersQuery = useGetV2SearchCopilotUsers(
|
||||
normalizedSearch ? { search: normalizedSearch, limit: 20 } : undefined,
|
||||
{
|
||||
query: {
|
||||
enabled: normalizedSearch.length > 0,
|
||||
select: okData,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const triggerCopilotSessionMutation = usePostV2TriggerCopilotSession({
|
||||
mutation: {
|
||||
onSuccess: (response) => {
|
||||
setPendingTriggerType(null);
|
||||
const session = okData(response) ?? null;
|
||||
setLastTriggeredSession(session);
|
||||
toast({
|
||||
title: "Copilot session created",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingTriggerType(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const sendPendingCopilotEmailsMutation = useMutation({
|
||||
mutationKey: ["sendPendingCopilotEmails"],
|
||||
mutationFn: async (userId: string) =>
|
||||
customMutator<{
|
||||
data: SendCopilotEmailsResponse;
|
||||
status: number;
|
||||
headers: Headers;
|
||||
}>("/api/users/admin/copilot/send-emails", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ user_id: userId }),
|
||||
}),
|
||||
onSuccess: (response) => {
|
||||
const result = okData(response) ?? null;
|
||||
setLastEmailSweepResult(result);
|
||||
if (!result) {
|
||||
toast({
|
||||
title: "Email sweep completed",
|
||||
variant: "default",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
toast({
|
||||
title:
|
||||
result.sent_count > 0
|
||||
? `Sent ${result.sent_count} Copilot email${result.sent_count === 1 ? "" : "s"}`
|
||||
: "Email sweep completed",
|
||||
description: [
|
||||
`${result.candidate_count} candidate${result.candidate_count === 1 ? "" : "s"}`,
|
||||
`${result.sent_count} sent`,
|
||||
`${result.skipped_count} skipped`,
|
||||
`${result.repair_queued_count} repairs queued`,
|
||||
`${result.running_count} still running`,
|
||||
`${result.failed_count} failed`,
|
||||
].join(" • "),
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error: unknown) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
function handleSelectUser(user: AdminCopilotUserSummary) {
|
||||
setSelectedUser(user);
|
||||
setLastTriggeredSession(null);
|
||||
setLastEmailSweepResult(null);
|
||||
}
|
||||
|
||||
function handleTriggerSession(startType: ChatSessionStartType) {
|
||||
if (!selectedUser) {
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingTriggerType(startType);
|
||||
setLastTriggeredSession(null);
|
||||
triggerCopilotSessionMutation.mutate({
|
||||
data: {
|
||||
user_id: selectedUser.id,
|
||||
start_type: startType,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleSendPendingEmails() {
|
||||
if (!selectedUser) {
|
||||
return;
|
||||
}
|
||||
|
||||
setLastEmailSweepResult(null);
|
||||
sendPendingCopilotEmailsMutation.mutate(selectedUser.id);
|
||||
}
|
||||
|
||||
return {
|
||||
search,
|
||||
selectedUser,
|
||||
pendingTriggerType,
|
||||
lastTriggeredSession,
|
||||
lastEmailSweepResult,
|
||||
searchedUsers: searchUsersQuery.data?.users ?? [],
|
||||
searchErrorMessage: searchUsersQuery.error
|
||||
? getErrorMessage(searchUsersQuery.error)
|
||||
: null,
|
||||
isSearchingUsers: searchUsersQuery.isLoading,
|
||||
isRefreshingUsers:
|
||||
searchUsersQuery.isFetching && !searchUsersQuery.isLoading,
|
||||
isTriggeringSession: triggerCopilotSessionMutation.isPending,
|
||||
isSendingPendingEmails: sendPendingCopilotEmailsMutation.isPending,
|
||||
hasSearch: normalizedSearch.length > 0,
|
||||
setSearch,
|
||||
handleSelectUser,
|
||||
handleTriggerSession,
|
||||
handleSendPendingEmails,
|
||||
};
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
MagnifyingGlassIcon,
|
||||
FileTextIcon,
|
||||
SlidersHorizontalIcon,
|
||||
LightningIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
const sidebarLinkGroups = [
|
||||
@@ -33,6 +34,11 @@ const sidebarLinkGroups = [
|
||||
href: "/admin/impersonation",
|
||||
icon: <MagnifyingGlassIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Copilot",
|
||||
href: "/admin/copilot",
|
||||
icon: <LightningIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
|
||||
@@ -23,25 +23,22 @@ import {
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
DotsThree,
|
||||
PlusCircleIcon,
|
||||
PlusIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { getSessionListParams } from "../../helpers";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { SessionListItem } from "../SessionListItem/SessionListItem";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
const isCollapsed = state === "collapsed";
|
||||
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
|
||||
const listSessionsParams = getSessionListParams();
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
@@ -52,7 +49,9 @@ export function ChatSidebar() {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { data: sessionsResponse, isLoading: isLoadingSessions } =
|
||||
useGetV2ListSessions({ limit: 50 }, { query: { refetchInterval: 10_000 } });
|
||||
useGetV2ListSessions(listSessionsParams, {
|
||||
query: { refetchInterval: 10_000 },
|
||||
});
|
||||
|
||||
const { mutate: deleteSession, isPending: isDeleting } =
|
||||
useDeleteV2DeleteSession({
|
||||
@@ -180,31 +179,6 @@ export function ChatSidebar() {
|
||||
}
|
||||
}
|
||||
|
||||
function formatDate(dateString: string) {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||
|
||||
if (diffDays === 0) return "Today";
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
const day = date.getDate();
|
||||
const ordinal =
|
||||
day % 10 === 1 && day !== 11
|
||||
? "st"
|
||||
: day % 10 === 2 && day !== 12
|
||||
? "nd"
|
||||
: day % 10 === 3 && day !== 13
|
||||
? "rd"
|
||||
: "th";
|
||||
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||
const year = date.getFullYear();
|
||||
|
||||
return `${day}${ordinal} ${month} ${year}`;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Sidebar
|
||||
@@ -295,17 +269,17 @@ export function ChatSidebar() {
|
||||
No conversations yet
|
||||
</p>
|
||||
) : (
|
||||
sessions.map((session) => (
|
||||
<div
|
||||
key={session.id}
|
||||
className={cn(
|
||||
"group relative w-full rounded-lg transition-colors",
|
||||
session.id === sessionId
|
||||
? "bg-zinc-100"
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
{editingSessionId === session.id ? (
|
||||
sessions.map((session) =>
|
||||
editingSessionId === session.id ? (
|
||||
<div
|
||||
key={session.id}
|
||||
className={cn(
|
||||
"group relative w-full rounded-lg transition-colors",
|
||||
session.id === sessionId
|
||||
? "bg-zinc-100"
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<div className="px-3 py-2.5">
|
||||
<input
|
||||
ref={renameInputRef}
|
||||
@@ -331,87 +305,49 @@ export function ChatSidebar() {
|
||||
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<SessionListItem
|
||||
key={session.id}
|
||||
session={session}
|
||||
currentSessionId={sessionId}
|
||||
isCompleted={completedSessionIDs.has(session.id)}
|
||||
onSelect={handleSelectSession}
|
||||
variant="sidebar"
|
||||
actionSlot={
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
{session.is_processing &&
|
||||
session.id !== sessionId &&
|
||||
!completedSessionIDs.has(session.id) && (
|
||||
<PulseLoader size={16} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== sessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
{editingSessionId !== session.id && (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)}
|
||||
</div>
|
||||
))
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
}
|
||||
/>
|
||||
),
|
||||
)
|
||||
)}
|
||||
</motion.div>
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
PlusIcon,
|
||||
SpeakerHigh,
|
||||
SpeakerSlash,
|
||||
@@ -12,8 +9,9 @@ import {
|
||||
X,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Drawer } from "vaul";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { SessionListItem } from "../SessionListItem/SessionListItem";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
@@ -26,31 +24,6 @@ interface Props {
|
||||
onOpenChange: (open: boolean) => void;
|
||||
}
|
||||
|
||||
function formatDate(dateString: string) {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||
|
||||
if (diffDays === 0) return "Today";
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
const day = date.getDate();
|
||||
const ordinal =
|
||||
day % 10 === 1 && day !== 11
|
||||
? "st"
|
||||
: day % 10 === 2 && day !== 12
|
||||
? "nd"
|
||||
: day % 10 === 3 && day !== 13
|
||||
? "rd"
|
||||
: "th";
|
||||
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||
const year = date.getFullYear();
|
||||
|
||||
return `${day}${ordinal} ${month} ${year}`;
|
||||
}
|
||||
|
||||
export function MobileDrawer({
|
||||
isOpen,
|
||||
sessions,
|
||||
@@ -134,52 +107,19 @@ export function MobileDrawer({
|
||||
</p>
|
||||
) : (
|
||||
sessions.map((session) => (
|
||||
<button
|
||||
<SessionListItem
|
||||
key={session.id}
|
||||
onClick={() => {
|
||||
onSelectSession(session.id);
|
||||
if (completedSessionIDs.has(session.id)) {
|
||||
clearCompletedSession(session.id);
|
||||
session={session}
|
||||
currentSessionId={currentSessionId}
|
||||
isCompleted={completedSessionIDs.has(session.id)}
|
||||
variant="drawer"
|
||||
onSelect={(selectedSessionId) => {
|
||||
onSelectSession(selectedSessionId);
|
||||
if (completedSessionIDs.has(selectedSessionId)) {
|
||||
clearCompletedSession(selectedSessionId);
|
||||
}
|
||||
}}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
session.id === currentSessionId
|
||||
? "bg-zinc-100"
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="flex min-w-0 max-w-full items-center gap-1.5">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === currentSessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</Text>
|
||||
{session.is_processing &&
|
||||
!completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<PulseLoader size={8} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
"use client";
|
||||
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CheckCircle } from "@phosphor-icons/react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import type { ReactNode } from "react";
|
||||
import {
|
||||
formatSessionDate,
|
||||
getSessionStartTypeLabel,
|
||||
isNonManualSessionStartType,
|
||||
} from "../../helpers";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
interface Props {
|
||||
actionSlot?: ReactNode;
|
||||
currentSessionId: string | null;
|
||||
isCompleted: boolean;
|
||||
onSelect: (sessionId: string) => void;
|
||||
session: SessionSummaryResponse;
|
||||
variant?: "sidebar" | "drawer";
|
||||
}
|
||||
|
||||
export function SessionListItem({
|
||||
actionSlot,
|
||||
currentSessionId,
|
||||
isCompleted,
|
||||
onSelect,
|
||||
session,
|
||||
variant = "sidebar",
|
||||
}: Props) {
|
||||
const isActive = session.id === currentSessionId;
|
||||
const showProcessing = session.is_processing && !isCompleted && !isActive;
|
||||
const showCompleted = isCompleted && !isActive;
|
||||
const startTypeLabel = isNonManualSessionStartType(session.start_type)
|
||||
? getSessionStartTypeLabel(session.start_type)
|
||||
: null;
|
||||
|
||||
if (variant === "drawer") {
|
||||
return (
|
||||
<button
|
||||
onClick={() => onSelect(session.id)}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="flex min-w-0 max-w-full items-center gap-1.5">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</Text>
|
||||
{showProcessing ? (
|
||||
<PulseLoader size={8} className="shrink-0" />
|
||||
) : null}
|
||||
{showCompleted ? (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
{startTypeLabel ? (
|
||||
<div className="mt-1">
|
||||
<Badge variant="info" size="small">
|
||||
{startTypeLabel}
|
||||
</Badge>
|
||||
</div>
|
||||
) : null}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatSessionDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group relative w-full rounded-lg transition-colors",
|
||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<button
|
||||
onClick={() => onSelect(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
{startTypeLabel ? (
|
||||
<div className="mt-1">
|
||||
<Badge variant="info" size="small">
|
||||
{startTypeLabel}
|
||||
</Badge>
|
||||
</div>
|
||||
) : null}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatSessionDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
{showProcessing ? (
|
||||
<PulseLoader size={16} className="shrink-0" />
|
||||
) : null}
|
||||
{showCompleted ? (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
</button>
|
||||
{actionSlot ? (
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2">
|
||||
{actionSlot}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,65 @@
|
||||
import type { GetV2ListSessionsParams } from "@/app/api/__generated__/models/getV2ListSessionsParams";
|
||||
import {
|
||||
ChatSessionStartType,
|
||||
type ChatSessionStartType as ChatSessionStartTypeValue,
|
||||
} from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import type { UIMessage } from "ai";
|
||||
|
||||
export const COPILOT_SESSION_LIST_LIMIT = 50;
|
||||
|
||||
export function getSessionListParams(): GetV2ListSessionsParams {
|
||||
return {
|
||||
limit: COPILOT_SESSION_LIST_LIMIT,
|
||||
with_auto: true,
|
||||
};
|
||||
}
|
||||
|
||||
export function isNonManualSessionStartType(
|
||||
startType: ChatSessionStartTypeValue | null | undefined,
|
||||
): boolean {
|
||||
return startType != null && startType !== ChatSessionStartType.MANUAL;
|
||||
}
|
||||
|
||||
export function getSessionStartTypeLabel(
|
||||
startType: ChatSessionStartTypeValue,
|
||||
): string | null {
|
||||
switch (startType) {
|
||||
case ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
return "Nightly";
|
||||
case ChatSessionStartType.AUTOPILOT_CALLBACK:
|
||||
return "Callback";
|
||||
case ChatSessionStartType.AUTOPILOT_INVITE_CTA:
|
||||
return "Invite CTA";
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function formatSessionDate(dateString: string): string {
|
||||
const date = new Date(dateString);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||
|
||||
if (diffDays === 0) return "Today";
|
||||
if (diffDays === 1) return "Yesterday";
|
||||
if (diffDays < 7) return `${diffDays} days ago`;
|
||||
|
||||
const day = date.getDate();
|
||||
const ordinal =
|
||||
day % 10 === 1 && day !== 11
|
||||
? "st"
|
||||
: day % 10 === 2 && day !== 12
|
||||
? "nd"
|
||||
: day % 10 === 3 && day !== 13
|
||||
? "rd"
|
||||
: "th";
|
||||
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||
const year = date.getFullYear();
|
||||
|
||||
return `${day}${ordinal} ${month} ${year}`;
|
||||
}
|
||||
|
||||
/** Mark any in-progress tool parts as completed/errored so spinners stop. */
|
||||
export function resolveInProgressTools(
|
||||
messages: UIMessage[],
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import { usePostV2ConsumeCallbackTokenRoute } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useState } from "react";
|
||||
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
interface Props {
|
||||
isLoggedIn: boolean;
|
||||
onConsumed: (sessionId: string) => void;
|
||||
onClearAutopilot: () => void;
|
||||
}
|
||||
|
||||
export function useCallbackToken({
|
||||
isLoggedIn,
|
||||
onConsumed,
|
||||
onClearAutopilot,
|
||||
}: Props) {
|
||||
const queryClient = useQueryClient();
|
||||
const [callbackToken, setCallbackToken] = useQueryState(
|
||||
"callbackToken",
|
||||
parseAsString,
|
||||
);
|
||||
const [consumedTokens, setConsumedTokens] = useState<Set<string>>(
|
||||
() => new Set(),
|
||||
);
|
||||
const { mutateAsync: consumeCallbackToken, isPending } =
|
||||
usePostV2ConsumeCallbackTokenRoute();
|
||||
|
||||
const hasConsumedToken =
|
||||
callbackToken != null && consumedTokens.has(callbackToken);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isLoggedIn || !callbackToken || hasConsumedToken) {
|
||||
return;
|
||||
}
|
||||
|
||||
let isCancelled = false;
|
||||
const token = callbackToken;
|
||||
setConsumedTokens((current) => new Set(current).add(token));
|
||||
|
||||
void consumeCallbackToken({ data: { token } })
|
||||
.then((response) => {
|
||||
if (isCancelled) {
|
||||
return;
|
||||
}
|
||||
if (response.status !== 200 || !response.data?.session_id) {
|
||||
throw new Error("Failed to open callback session");
|
||||
}
|
||||
|
||||
onConsumed(response.data.session_id);
|
||||
onClearAutopilot();
|
||||
void setCallbackToken(null);
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
if (isCancelled) {
|
||||
return;
|
||||
}
|
||||
setConsumedTokens((current) => {
|
||||
const next = new Set(current);
|
||||
next.delete(token);
|
||||
return next;
|
||||
});
|
||||
void setCallbackToken(null);
|
||||
toast({
|
||||
title: "Unable to open callback session",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
});
|
||||
|
||||
return () => {
|
||||
isCancelled = true;
|
||||
};
|
||||
}, [
|
||||
callbackToken,
|
||||
consumeCallbackToken,
|
||||
hasConsumedToken,
|
||||
isLoggedIn,
|
||||
onClearAutopilot,
|
||||
onConsumed,
|
||||
queryClient,
|
||||
setCallbackToken,
|
||||
]);
|
||||
|
||||
return {
|
||||
isConsumingCallbackToken: isPending,
|
||||
};
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
useGetV2GetSession,
|
||||
usePostV2CreateSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
@@ -70,6 +71,14 @@ export function useChatSession() {
|
||||
);
|
||||
}, [sessionQuery.data, sessionId, hasActiveStream]);
|
||||
|
||||
const sessionStartType = useMemo<ChatSessionStartType | null>(() => {
|
||||
if (sessionQuery.data?.status !== 200) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return sessionQuery.data.data.start_type;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
|
||||
usePostV2CreateSession({
|
||||
mutation: {
|
||||
@@ -121,6 +130,7 @@ export function useChatSession() {
|
||||
return {
|
||||
sessionId,
|
||||
setSessionId,
|
||||
sessionStartType,
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
|
||||
@@ -2,34 +2,24 @@ import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
type getV2ListSessionsResponse,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { uploadFileDirect } from "@/lib/direct-upload";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { getSessionListParams } from "./helpers";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
import { useCallbackToken } from "./useCallbackToken";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useFileUpload } from "./useFileUpload";
|
||||
import { useCopilotNotifications } from "./useCopilotNotifications";
|
||||
import { useCopilotStream } from "./useCopilotStream";
|
||||
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
mime_type: string;
|
||||
}
|
||||
import { useTitlePolling } from "./useTitlePolling";
|
||||
|
||||
export function useCopilotPage() {
|
||||
const { isUserLoading, isLoggedIn } = useSupabase();
|
||||
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
const listSessionsParams = getSessionListParams();
|
||||
|
||||
const { sessionToDelete, setSessionToDelete, isDrawerOpen, setDrawerOpen } =
|
||||
useCopilotUIStore();
|
||||
@@ -39,7 +29,7 @@ export function useCopilotPage() {
|
||||
setSessionId,
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
isLoadingSession,
|
||||
isLoadingSession: isLoadingCurrentSession,
|
||||
isSessionError,
|
||||
createSession,
|
||||
isCreatingSession,
|
||||
@@ -93,198 +83,33 @@ export function useCopilotPage() {
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
const pendingFilesRef = useRef<File[]>([]);
|
||||
const { isConsumingCallbackToken } = useCallbackToken({
|
||||
isLoggedIn,
|
||||
onConsumed: setSessionId,
|
||||
onClearAutopilot() {},
|
||||
});
|
||||
|
||||
// --- Send pending message after session creation ---
|
||||
useEffect(() => {
|
||||
if (!sessionId || pendingMessage === null) return;
|
||||
const msg = pendingMessage;
|
||||
const files = pendingFilesRef.current;
|
||||
setPendingMessage(null);
|
||||
pendingFilesRef.current = [];
|
||||
|
||||
if (files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
void uploadFiles(files, sessionId)
|
||||
.then((uploaded) => {
|
||||
if (uploaded.length === 0) {
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: "Could not upload any files. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: msg,
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
})
|
||||
.finally(() => setIsUploadingFiles(false));
|
||||
} else {
|
||||
sendMessage({ text: msg });
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sid: string,
|
||||
): Promise<UploadedFile[]> {
|
||||
const results = await Promise.allSettled(
|
||||
files.map(async (file) => {
|
||||
try {
|
||||
const data = await uploadFileDirect(file, sid);
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} as UploadedFile;
|
||||
} catch (err) {
|
||||
console.error("File upload failed:", err);
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: file.name,
|
||||
variant: "destructive",
|
||||
});
|
||||
throw err;
|
||||
}
|
||||
}),
|
||||
);
|
||||
return results
|
||||
.filter(
|
||||
(r): r is PromiseFulfilledResult<UploadedFile> =>
|
||||
r.status === "fulfilled",
|
||||
)
|
||||
.map((r) => r.value);
|
||||
}
|
||||
|
||||
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
|
||||
return uploaded.map((f) => ({
|
||||
type: "file" as const,
|
||||
mediaType: f.mime_type,
|
||||
filename: f.name,
|
||||
url: `/api/proxy/api/workspace/files/${f.file_id}/download`,
|
||||
}));
|
||||
}
|
||||
|
||||
async function onSend(message: string, files?: File[]) {
|
||||
const trimmed = message.trim();
|
||||
if (!trimmed && (!files || files.length === 0)) return;
|
||||
|
||||
// Client-side file limits
|
||||
if (files && files.length > 0) {
|
||||
const MAX_FILES = 10;
|
||||
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024; // 100 MB
|
||||
|
||||
if (files.length > MAX_FILES) {
|
||||
toast({
|
||||
title: "Too many files",
|
||||
description: `You can attach up to ${MAX_FILES} files at once.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const oversized = files.filter((f) => f.size > MAX_FILE_SIZE_BYTES);
|
||||
if (oversized.length > 0) {
|
||||
toast({
|
||||
title: "File too large",
|
||||
description: `${oversized[0].name} exceeds the 100 MB limit.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
if (files && files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
const uploaded = await uploadFiles(files, sessionId);
|
||||
if (uploaded.length === 0) {
|
||||
// All uploads failed — abort send so chips revert to editable
|
||||
throw new Error("All file uploads failed");
|
||||
}
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: trimmed || "",
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
} finally {
|
||||
setIsUploadingFiles(false);
|
||||
}
|
||||
} else {
|
||||
sendMessage({ text: trimmed });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingMessage(trimmed || "");
|
||||
if (files && files.length > 0) {
|
||||
pendingFilesRef.current = files;
|
||||
}
|
||||
await createSession();
|
||||
}
|
||||
const { isUploadingFiles, onSend } = useFileUpload({
|
||||
createSession,
|
||||
isUserStoppingRef,
|
||||
sendMessage,
|
||||
sessionId,
|
||||
});
|
||||
|
||||
// --- Session list (for mobile drawer & sidebar) ---
|
||||
const { data: sessionsResponse, isLoading: isLoadingSessions } =
|
||||
useGetV2ListSessions(
|
||||
{ limit: 50 },
|
||||
{ query: { enabled: !isUserLoading && isLoggedIn } },
|
||||
);
|
||||
useGetV2ListSessions(listSessionsParams, {
|
||||
query: { enabled: !isUserLoading && isLoggedIn },
|
||||
});
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
// Start title polling when stream ends cleanly — sidebar title animates in
|
||||
const titlePollRef = useRef<ReturnType<typeof setInterval>>();
|
||||
const prevStatusRef = useRef(status);
|
||||
|
||||
useEffect(() => {
|
||||
const prev = prevStatusRef.current;
|
||||
prevStatusRef.current = status;
|
||||
|
||||
const wasActive = prev === "streaming" || prev === "submitted";
|
||||
const isNowReady = status === "ready";
|
||||
|
||||
if (!wasActive || !isNowReady || !sessionId || isReconnecting) return;
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
});
|
||||
const sid = sessionId;
|
||||
let attempts = 0;
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = setInterval(() => {
|
||||
const data = queryClient.getQueryData<getV2ListSessionsResponse>(
|
||||
getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
);
|
||||
const hasTitle =
|
||||
data?.status === 200 &&
|
||||
data.data.sessions.some((s) => s.id === sid && s.title);
|
||||
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = undefined;
|
||||
return;
|
||||
}
|
||||
attempts += 1;
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
});
|
||||
}, TITLE_POLL_INTERVAL_MS);
|
||||
}, [status, sessionId, isReconnecting, queryClient]);
|
||||
|
||||
// Clean up polling on session change or unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = undefined;
|
||||
};
|
||||
}, [sessionId]);
|
||||
useTitlePolling({
|
||||
isReconnecting,
|
||||
sessionId,
|
||||
status,
|
||||
});
|
||||
|
||||
// --- Mobile drawer handlers ---
|
||||
function handleOpenDrawer() {
|
||||
@@ -334,7 +159,7 @@ export function useCopilotPage() {
|
||||
error,
|
||||
stop,
|
||||
isReconnecting,
|
||||
isLoadingSession,
|
||||
isLoadingSession: isLoadingCurrentSession || isConsumingCallbackToken,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUploadingFiles,
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
import { uploadFileDirect } from "@/lib/direct-upload";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useEffect, useRef, useState, type MutableRefObject } from "react";
|
||||
|
||||
const MAX_FILES = 10;
|
||||
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024;
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
mime_type: string;
|
||||
}
|
||||
|
||||
interface SendMessageInput {
|
||||
text: string;
|
||||
files?: FileUIPart[];
|
||||
}
|
||||
|
||||
interface Props {
|
||||
createSession: () => Promise<unknown>;
|
||||
isUserStoppingRef: MutableRefObject<boolean>;
|
||||
sendMessage: (input: SendMessageInput) => void;
|
||||
sessionId: string | null;
|
||||
}
|
||||
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sessionId: string,
|
||||
): Promise<UploadedFile[]> {
|
||||
const results = await Promise.allSettled(
|
||||
files.map(async (file) => {
|
||||
try {
|
||||
const data = await uploadFileDirect(file, sessionId);
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} satisfies UploadedFile;
|
||||
} catch (error) {
|
||||
console.error("File upload failed:", error);
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: file.name,
|
||||
variant: "destructive",
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
return results
|
||||
.filter(
|
||||
(result): result is PromiseFulfilledResult<UploadedFile> =>
|
||||
result.status === "fulfilled",
|
||||
)
|
||||
.map((result) => result.value);
|
||||
}
|
||||
|
||||
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
|
||||
return uploaded.map((file) => ({
|
||||
type: "file" as const,
|
||||
mediaType: file.mime_type,
|
||||
filename: file.name,
|
||||
url: `/api/proxy/api/workspace/files/${file.file_id}/download`,
|
||||
}));
|
||||
}
|
||||
|
||||
export function useFileUpload({
|
||||
createSession,
|
||||
isUserStoppingRef,
|
||||
sendMessage,
|
||||
sessionId,
|
||||
}: Props) {
|
||||
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const pendingFilesRef = useRef<File[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!sessionId || pendingMessage === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
const message = pendingMessage;
|
||||
const files = pendingFilesRef.current;
|
||||
setPendingMessage(null);
|
||||
pendingFilesRef.current = [];
|
||||
|
||||
if (files.length === 0) {
|
||||
sendMessage({ text: message });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsUploadingFiles(true);
|
||||
void uploadFiles(files, sessionId)
|
||||
.then((uploaded) => {
|
||||
if (uploaded.length === 0) {
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: "Could not upload any files. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: message,
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
})
|
||||
.finally(() => setIsUploadingFiles(false));
|
||||
}, [pendingMessage, sendMessage, sessionId]);
|
||||
|
||||
async function onSend(message: string, files?: File[]) {
|
||||
const trimmed = message.trim();
|
||||
if (!trimmed && (!files || files.length === 0)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (files && files.length > 0) {
|
||||
if (files.length > MAX_FILES) {
|
||||
toast({
|
||||
title: "Too many files",
|
||||
description: `You can attach up to ${MAX_FILES} files at once.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const oversized = files.filter((file) => file.size > MAX_FILE_SIZE_BYTES);
|
||||
if (oversized.length > 0) {
|
||||
toast({
|
||||
title: "File too large",
|
||||
description: `${oversized[0].name} exceeds the 100 MB limit.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
if (!files || files.length === 0) {
|
||||
sendMessage({ text: trimmed });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
const uploaded = await uploadFiles(files, sessionId);
|
||||
if (uploaded.length === 0) {
|
||||
throw new Error("All file uploads failed");
|
||||
}
|
||||
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: trimmed || "",
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
} finally {
|
||||
setIsUploadingFiles(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingMessage(trimmed || "");
|
||||
pendingFilesRef.current = files ?? [];
|
||||
await createSession();
|
||||
}
|
||||
|
||||
return {
|
||||
isUploadingFiles,
|
||||
onSend,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
type getV2ListSessionsResponse,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { getSessionListParams } from "./helpers";
|
||||
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
interface Props {
|
||||
isReconnecting: boolean;
|
||||
sessionId: string | null;
|
||||
status: string;
|
||||
}
|
||||
|
||||
export function useTitlePolling({ isReconnecting, sessionId, status }: Props) {
|
||||
const queryClient = useQueryClient();
|
||||
const previousStatusRef = useRef(status);
|
||||
|
||||
useEffect(() => {
|
||||
const previousStatus = previousStatusRef.current;
|
||||
previousStatusRef.current = status;
|
||||
|
||||
const wasActive =
|
||||
previousStatus === "streaming" || previousStatus === "submitted";
|
||||
const isNowReady = status === "ready";
|
||||
|
||||
if (!wasActive || !isNowReady || !sessionId || isReconnecting) {
|
||||
return;
|
||||
}
|
||||
|
||||
const params = getSessionListParams();
|
||||
const queryKey = getGetV2ListSessionsQueryKey(params);
|
||||
let attempts = 0;
|
||||
let timeoutId: ReturnType<typeof setTimeout> | undefined;
|
||||
let isCancelled = false;
|
||||
|
||||
const poll = () => {
|
||||
if (isCancelled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const data =
|
||||
queryClient.getQueryData<getV2ListSessionsResponse>(queryKey);
|
||||
const hasTitle =
|
||||
data?.status === 200 &&
|
||||
data.data.sessions.some(
|
||||
(session) => session.id === sessionId && session.title,
|
||||
);
|
||||
|
||||
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
|
||||
return;
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
queryClient.invalidateQueries({ queryKey });
|
||||
timeoutId = setTimeout(poll, TITLE_POLL_INTERVAL_MS);
|
||||
};
|
||||
|
||||
queryClient.invalidateQueries({ queryKey });
|
||||
timeoutId = setTimeout(poll, TITLE_POLL_INTERVAL_MS);
|
||||
|
||||
return () => {
|
||||
isCancelled = true;
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
};
|
||||
}, [isReconnecting, queryClient, sessionId, status]);
|
||||
}
|
||||
@@ -1030,6 +1030,16 @@
|
||||
"default": 0,
|
||||
"title": "Offset"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "with_auto",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"title": "With Auto"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -1079,6 +1089,47 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/callback-token/consume": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Consume Callback Token Route",
|
||||
"operationId": "postV2ConsumeCallbackTokenRoute",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ConsumeCallbackTokenRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ConsumeCallbackTokenResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}": {
|
||||
"delete": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -6670,6 +6721,145 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/copilot/send-emails": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Send Pending Copilot Emails",
|
||||
"operationId": "postV2SendPendingCopilotEmails",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SendCopilotEmailsRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SendCopilotEmailsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/users/admin/copilot/trigger": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Trigger Copilot Session",
|
||||
"operationId": "postV2TriggerCopilotSession",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/TriggerCopilotSessionRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/TriggerCopilotSessionResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/users/admin/copilot/users": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Search Copilot Users",
|
||||
"operationId": "getV2SearchCopilotUsers",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "search",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Search by email, name, or user ID",
|
||||
"default": "",
|
||||
"title": "Search"
|
||||
},
|
||||
"description": "Search by email, name, or user ID"
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 50,
|
||||
"minimum": 1,
|
||||
"default": 20,
|
||||
"title": "Limit"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/AdminCopilotUsersResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
@@ -7270,6 +7460,42 @@
|
||||
"required": ["new_balance", "transaction_key"],
|
||||
"title": "AddUserCreditsResponse"
|
||||
},
|
||||
"AdminCopilotUserSummary": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"email": { "type": "string", "title": "Email" },
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
},
|
||||
"timezone": { "type": "string", "title": "Timezone" },
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Created At"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Updated At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "email", "timezone", "created_at", "updated_at"],
|
||||
"title": "AdminCopilotUserSummary"
|
||||
},
|
||||
"AdminCopilotUsersResponse": {
|
||||
"properties": {
|
||||
"users": {
|
||||
"items": { "$ref": "#/components/schemas/AdminCopilotUserSummary" },
|
||||
"type": "array",
|
||||
"title": "Users"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["users"],
|
||||
"title": "AdminCopilotUsersResponse"
|
||||
},
|
||||
"AgentDetails": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -7283,14 +7509,12 @@
|
||||
"inputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Inputs",
|
||||
"default": {}
|
||||
"title": "Inputs"
|
||||
},
|
||||
"credentials": {
|
||||
"items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
|
||||
"type": "array",
|
||||
"title": "Credentials",
|
||||
"default": []
|
||||
"title": "Credentials"
|
||||
},
|
||||
"execution_options": {
|
||||
"$ref": "#/components/schemas/ExecutionOptions"
|
||||
@@ -7867,20 +8091,17 @@
|
||||
"inputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Inputs",
|
||||
"default": {}
|
||||
"title": "Inputs"
|
||||
},
|
||||
"outputs": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Outputs",
|
||||
"default": {}
|
||||
"title": "Outputs"
|
||||
},
|
||||
"credentials": {
|
||||
"items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
|
||||
"type": "array",
|
||||
"title": "Credentials",
|
||||
"default": []
|
||||
"title": "Credentials"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -8419,6 +8640,16 @@
|
||||
"required": ["query", "conversation_history", "message_id"],
|
||||
"title": "ChatRequest"
|
||||
},
|
||||
"ChatSessionStartType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"MANUAL",
|
||||
"AUTOPILOT_NIGHTLY",
|
||||
"AUTOPILOT_CALLBACK",
|
||||
"AUTOPILOT_INVITE_CTA"
|
||||
],
|
||||
"title": "ChatSessionStartType"
|
||||
},
|
||||
"ClarificationNeededResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
@@ -8455,6 +8686,20 @@
|
||||
"title": "ClarifyingQuestion",
|
||||
"description": "A question that needs user clarification."
|
||||
},
|
||||
"ConsumeCallbackTokenRequest": {
|
||||
"properties": { "token": { "type": "string", "title": "Token" } },
|
||||
"type": "object",
|
||||
"required": ["token"],
|
||||
"title": "ConsumeCallbackTokenRequest"
|
||||
},
|
||||
"ConsumeCallbackTokenResponse": {
|
||||
"properties": {
|
||||
"session_id": { "type": "string", "title": "Session Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["session_id"],
|
||||
"title": "ConsumeCallbackTokenResponse"
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -10878,8 +11123,7 @@
|
||||
"suggestions": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Suggestions",
|
||||
"default": []
|
||||
"title": "Suggestions"
|
||||
},
|
||||
"name": { "type": "string", "title": "Name", "default": "no_results" }
|
||||
},
|
||||
@@ -11915,6 +12159,7 @@
|
||||
"error",
|
||||
"no_results",
|
||||
"need_login",
|
||||
"completion_report_saved",
|
||||
"agents_found",
|
||||
"agent_details",
|
||||
"setup_requirements",
|
||||
@@ -12171,6 +12416,37 @@
|
||||
"required": ["items", "search_id", "total_items", "pagination"],
|
||||
"title": "SearchResponse"
|
||||
},
|
||||
"SendCopilotEmailsRequest": {
|
||||
"properties": { "user_id": { "type": "string", "title": "User Id" } },
|
||||
"type": "object",
|
||||
"required": ["user_id"],
|
||||
"title": "SendCopilotEmailsRequest"
|
||||
},
|
||||
"SendCopilotEmailsResponse": {
|
||||
"properties": {
|
||||
"candidate_count": { "type": "integer", "title": "Candidate Count" },
|
||||
"processed_count": { "type": "integer", "title": "Processed Count" },
|
||||
"sent_count": { "type": "integer", "title": "Sent Count" },
|
||||
"skipped_count": { "type": "integer", "title": "Skipped Count" },
|
||||
"repair_queued_count": {
|
||||
"type": "integer",
|
||||
"title": "Repair Queued Count"
|
||||
},
|
||||
"running_count": { "type": "integer", "title": "Running Count" },
|
||||
"failed_count": { "type": "integer", "title": "Failed Count" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"candidate_count",
|
||||
"processed_count",
|
||||
"sent_count",
|
||||
"skipped_count",
|
||||
"repair_queued_count",
|
||||
"running_count",
|
||||
"failed_count"
|
||||
],
|
||||
"title": "SendCopilotEmailsResponse"
|
||||
},
|
||||
"SessionDetailResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -12180,6 +12456,11 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
},
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" },
|
||||
"execution_tag": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Execution Tag"
|
||||
},
|
||||
"messages": {
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
@@ -12193,7 +12474,14 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "created_at", "updated_at", "user_id", "messages"],
|
||||
"required": [
|
||||
"id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"user_id",
|
||||
"start_type",
|
||||
"messages"
|
||||
],
|
||||
"title": "SessionDetailResponse",
|
||||
"description": "Response model providing complete details for a chat session, including messages."
|
||||
},
|
||||
@@ -12206,10 +12494,21 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" },
|
||||
"execution_tag": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Execution Tag"
|
||||
},
|
||||
"is_processing": { "type": "boolean", "title": "Is Processing" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "created_at", "updated_at", "is_processing"],
|
||||
"required": [
|
||||
"id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"start_type",
|
||||
"is_processing"
|
||||
],
|
||||
"title": "SessionSummaryResponse",
|
||||
"description": "Response model for a session summary (without messages)."
|
||||
},
|
||||
@@ -13840,6 +14139,24 @@
|
||||
"required": ["transactions", "next_transaction_time"],
|
||||
"title": "TransactionHistory"
|
||||
},
|
||||
"TriggerCopilotSessionRequest": {
|
||||
"properties": {
|
||||
"user_id": { "type": "string", "title": "User Id" },
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["user_id", "start_type"],
|
||||
"title": "TriggerCopilotSessionRequest"
|
||||
},
|
||||
"TriggerCopilotSessionResponse": {
|
||||
"properties": {
|
||||
"session_id": { "type": "string", "title": "Session Id" },
|
||||
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["session_id", "start_type"],
|
||||
"title": "TriggerCopilotSessionResponse"
|
||||
},
|
||||
"TriggeredPresetSetupRequest": {
|
||||
"properties": {
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
@@ -14771,8 +15088,7 @@
|
||||
"missing_credentials": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Missing Credentials",
|
||||
"default": {}
|
||||
"title": "Missing Credentials"
|
||||
},
|
||||
"ready_to_run": {
|
||||
"type": "boolean",
|
||||
|
||||
Reference in New Issue
Block a user