mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
PR Comments
This commit is contained in:
@@ -11,7 +11,6 @@ from uuid import uuid4
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
@@ -61,6 +60,7 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
@@ -738,18 +738,14 @@ async def stream_chat_post(
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
files = await workspace_db().get_workspace_files_by_ids(
|
||||
workspace_id=workspace.id,
|
||||
file_ids=valid_ids,
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
f"- {wf.name} ({wf.mime_type}, {round(wf.size_bytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
|
||||
@@ -346,11 +346,11 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
@@ -376,11 +376,11 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
@@ -398,9 +398,10 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
|
||||
workspace_id="ws-1",
|
||||
file_ids=[valid_id],
|
||||
)
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
@@ -414,11 +415,11 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
workspace_store = mocker.MagicMock()
|
||||
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
"backend.api.features.chat.routes.workspace_db",
|
||||
return_value=workspace_store,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
@@ -427,9 +428,10 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
|
||||
workspace_id="my-workspace-id",
|
||||
file_ids=[fid],
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
@@ -4,12 +4,11 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
@@ -29,14 +28,23 @@ from backend.copilot.session_types import (
|
||||
CompletionReportInput,
|
||||
StoredCompletionReport,
|
||||
)
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
get_business_understanding,
|
||||
parse_business_understanding_input,
|
||||
from backend.data.db_accessors import (
|
||||
chat_db,
|
||||
invited_user_db,
|
||||
review_db,
|
||||
understanding_db,
|
||||
user_db,
|
||||
)
|
||||
from backend.data.model import User
|
||||
from backend.data.understanding import format_understanding_for_prompt
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
from backend.util.url import get_frontend_base_url
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import InvitedUserRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -104,26 +112,13 @@ def get_invite_cta_execution_tag() -> str:
|
||||
return AUTOPILOT_INVITE_CTA_TAG
|
||||
|
||||
|
||||
def _get_frontend_base_url() -> str:
|
||||
return (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
).rstrip("/")
|
||||
|
||||
|
||||
def _bucket_end_for_now(now_utc: datetime) -> datetime:
|
||||
minute = 30 if now_utc.minute >= 30 else 0
|
||||
return now_utc.replace(minute=minute, second=0, microsecond=0)
|
||||
|
||||
|
||||
def _resolve_timezone_name(raw_timezone: str | None) -> str:
|
||||
if not raw_timezone or raw_timezone == "not-set":
|
||||
return "UTC"
|
||||
try:
|
||||
ZoneInfo(raw_timezone)
|
||||
return raw_timezone
|
||||
except ZoneInfoNotFoundError:
|
||||
logger.warning("Unknown timezone %s; falling back to UTC", raw_timezone)
|
||||
return "UTC"
|
||||
return get_user_timezone_or_utc(raw_timezone)
|
||||
|
||||
|
||||
def _crosses_local_midnight(
|
||||
@@ -140,50 +135,28 @@ def _crosses_local_midnight(
|
||||
|
||||
|
||||
async def _user_has_recent_manual_message(user_id: str, since: datetime) -> bool:
|
||||
message = await prisma.models.ChatMessage.prisma().find_first(
|
||||
where={
|
||||
"role": "user",
|
||||
"createdAt": {"gte": since},
|
||||
"Session": {
|
||||
"is": {
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
return message is not None
|
||||
return await chat_db().has_recent_manual_message(user_id, since)
|
||||
|
||||
|
||||
async def _user_has_session_since(user_id: str, since: datetime) -> bool:
|
||||
session = await prisma.models.ChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "createdAt": {"gte": since}}
|
||||
)
|
||||
return session is not None
|
||||
return await chat_db().has_session_since(user_id, since)
|
||||
|
||||
|
||||
async def _session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
|
||||
existing = await prisma.models.ChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "executionTag": execution_tag}
|
||||
)
|
||||
return existing is not None
|
||||
return await chat_db().session_exists_for_execution_tag(user_id, execution_tag)
|
||||
|
||||
|
||||
def _get_invited_user_tally_understanding(
|
||||
invited_user: prisma.models.InvitedUser | None,
|
||||
invited_user: InvitedUserRecord | None,
|
||||
) -> dict[str, Any] | None:
|
||||
if invited_user is None:
|
||||
return None
|
||||
|
||||
understanding = parse_business_understanding_input(invited_user.tallyUnderstanding)
|
||||
return understanding.model_dump(mode="json") if understanding is not None else None
|
||||
return invited_user.tally_understanding if invited_user is not None else None
|
||||
|
||||
|
||||
def _render_initial_message(
|
||||
start_type: ChatSessionStartType,
|
||||
*,
|
||||
user_name: str | None,
|
||||
invited_user: prisma.models.InvitedUser | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> str:
|
||||
display_name = user_name or "the user"
|
||||
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
|
||||
@@ -263,14 +236,10 @@ async def _get_recent_manual_session_context(
|
||||
*,
|
||||
since_utc: datetime,
|
||||
) -> str:
|
||||
sessions = await prisma.models.ChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
"updatedAt": {"gte": since_utc},
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=AUTOPILOT_RECENT_SESSION_LIMIT,
|
||||
sessions = await chat_db().get_manual_chat_sessions_since(
|
||||
user_id,
|
||||
since_utc,
|
||||
AUTOPILOT_RECENT_SESSION_LIMIT,
|
||||
)
|
||||
|
||||
if not sessions:
|
||||
@@ -280,12 +249,8 @@ async def _get_recent_manual_session_context(
|
||||
used_chars = 0
|
||||
|
||||
for session in sessions:
|
||||
messages = await prisma.models.ChatMessage.prisma().find_many(
|
||||
where={
|
||||
"sessionId": session.id,
|
||||
"createdAt": {"gte": since_utc},
|
||||
},
|
||||
order={"sequence": "asc"},
|
||||
messages = await chat_db().get_chat_messages_since(
|
||||
session.session_id, since_utc
|
||||
)
|
||||
|
||||
visible_messages: list[str] = []
|
||||
@@ -312,7 +277,7 @@ async def _get_recent_manual_session_context(
|
||||
|
||||
title_suffix = f" ({session.title})" if session.title else ""
|
||||
block = (
|
||||
f"### Session updated {session.updatedAt.isoformat()}{title_suffix}\n"
|
||||
f"### Session updated {session.updated_at.isoformat()}{title_suffix}\n"
|
||||
+ "\n".join(visible_messages)
|
||||
)
|
||||
if used_chars + len(block) > AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT:
|
||||
@@ -329,14 +294,14 @@ async def _get_recent_manual_session_context(
|
||||
|
||||
|
||||
async def _build_autopilot_system_prompt(
|
||||
user: prisma.models.User,
|
||||
user: User,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
timezone_name: str,
|
||||
target_local_date: date | None = None,
|
||||
invited_user: prisma.models.InvitedUser | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> str:
|
||||
understanding = await get_business_understanding(user.id)
|
||||
understanding = await understanding_db().get_business_understanding(user.id)
|
||||
context_sections = [
|
||||
(
|
||||
format_understanding_for_prompt(understanding)
|
||||
@@ -400,13 +365,13 @@ async def _enqueue_session_turn(
|
||||
|
||||
|
||||
async def _create_autopilot_session(
|
||||
user: prisma.models.User,
|
||||
user: User,
|
||||
*,
|
||||
start_type: ChatSessionStartType,
|
||||
execution_tag: str,
|
||||
timezone_name: str,
|
||||
target_local_date: date | None = None,
|
||||
invited_user: prisma.models.InvitedUser | None = None,
|
||||
invited_user: InvitedUserRecord | None = None,
|
||||
) -> ChatSession | None:
|
||||
if await _session_exists_for_execution_tag(user.id, execution_tag):
|
||||
return None
|
||||
@@ -446,9 +411,9 @@ async def _create_autopilot_session(
|
||||
|
||||
|
||||
async def _try_create_invite_cta_session(
|
||||
user: prisma.models.User,
|
||||
user: User,
|
||||
*,
|
||||
invited_user: prisma.models.InvitedUser | None,
|
||||
invited_user: InvitedUserRecord | None,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
invite_cta_start: date,
|
||||
@@ -458,9 +423,9 @@ async def _try_create_invite_cta_session(
|
||||
return False
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
return False
|
||||
if invited_user.createdAt.date() < invite_cta_start:
|
||||
if invited_user.created_at.date() < invite_cta_start:
|
||||
return False
|
||||
if invited_user.createdAt > now_utc - invite_cta_delay:
|
||||
if invited_user.created_at > now_utc - invite_cta_delay:
|
||||
return False
|
||||
if await _session_exists_for_execution_tag(user.id, get_invite_cta_execution_tag()):
|
||||
return False
|
||||
@@ -476,7 +441,7 @@ async def _try_create_invite_cta_session(
|
||||
|
||||
|
||||
async def _try_create_nightly_session(
|
||||
user: prisma.models.User,
|
||||
user: User,
|
||||
*,
|
||||
now_utc: datetime,
|
||||
timezone_name: str,
|
||||
@@ -499,7 +464,7 @@ async def _try_create_nightly_session(
|
||||
|
||||
|
||||
async def _try_create_callback_session(
|
||||
user: prisma.models.User,
|
||||
user: User,
|
||||
*,
|
||||
callback_start: datetime,
|
||||
timezone_name: str,
|
||||
@@ -518,7 +483,7 @@ async def _try_create_callback_session(
|
||||
return created is not None
|
||||
|
||||
|
||||
async def dispatch_nightly_copilot() -> int:
|
||||
async def _dispatch_nightly_copilot() -> int:
|
||||
now_utc = datetime.now(UTC)
|
||||
bucket_end = _bucket_end_for_now(now_utc)
|
||||
bucket_start = bucket_end - timedelta(minutes=30)
|
||||
@@ -532,16 +497,12 @@ async def dispatch_nightly_copilot() -> int:
|
||||
hours=settings.config.nightly_copilot_invite_cta_delay_hours
|
||||
)
|
||||
|
||||
users = await prisma.models.User.prisma().find_many()
|
||||
invites = await prisma.models.InvitedUser.prisma().find_many(
|
||||
where={
|
||||
"authUserId": {
|
||||
"in": [user.id for user in users],
|
||||
}
|
||||
}
|
||||
users = await user_db().list_users()
|
||||
invites = await invited_user_db().list_invited_users_for_auth_users(
|
||||
[user.id for user in users]
|
||||
)
|
||||
invites_by_user_id = {
|
||||
invite.authUserId: invite for invite in invites if invite.authUserId
|
||||
invite.auth_user_id: invite for invite in invites if invite.auth_user_id
|
||||
}
|
||||
|
||||
created_count = 0
|
||||
@@ -589,16 +550,17 @@ async def dispatch_nightly_copilot() -> int:
|
||||
return created_count
|
||||
|
||||
|
||||
async def dispatch_nightly_copilot() -> int:
|
||||
return await _dispatch_nightly_copilot()
|
||||
|
||||
|
||||
async def _get_pending_approval_metadata(
|
||||
session: ChatSession,
|
||||
) -> tuple[int, str | None]:
|
||||
graph_exec_id = get_graph_exec_id_for_session(session.session_id)
|
||||
pending_count = await prisma.models.PendingHumanReview.prisma().count(
|
||||
where={
|
||||
"userId": session.user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"status": prisma.enums.ReviewStatus.WAITING,
|
||||
}
|
||||
pending_count = await review_db().count_pending_reviews_for_graph_exec(
|
||||
graph_exec_id,
|
||||
session.user_id,
|
||||
)
|
||||
return pending_count, graph_exec_id if pending_count > 0 else None
|
||||
|
||||
@@ -740,32 +702,31 @@ def _split_email_paragraphs(text: str | None) -> list[str]:
|
||||
|
||||
async def _create_callback_token(
|
||||
session: ChatSession,
|
||||
) -> prisma.models.ChatSessionCallbackToken:
|
||||
) -> str:
|
||||
if session.completion_report is None:
|
||||
raise ValueError("Missing completion report")
|
||||
callback_session_message = session.completion_report.callback_session_message
|
||||
if callback_session_message is None:
|
||||
raise ValueError("Missing callback session message")
|
||||
|
||||
return await prisma.models.ChatSessionCallbackToken.prisma().create(
|
||||
data={
|
||||
"userId": session.user_id,
|
||||
"sourceSessionId": session.session_id,
|
||||
"callbackSessionMessage": callback_session_message,
|
||||
"expiresAt": datetime.now(UTC)
|
||||
+ timedelta(hours=settings.config.nightly_copilot_callback_token_ttl_hours),
|
||||
}
|
||||
token = await chat_db().create_chat_session_callback_token(
|
||||
user_id=session.user_id,
|
||||
source_session_id=session.session_id,
|
||||
callback_session_message=callback_session_message,
|
||||
expires_at=datetime.now(UTC)
|
||||
+ timedelta(hours=settings.config.nightly_copilot_callback_token_ttl_hours),
|
||||
)
|
||||
return token.id
|
||||
|
||||
|
||||
def _build_session_link(session_id: str, *, show_autopilot: bool) -> str:
|
||||
base_url = _get_frontend_base_url()
|
||||
base_url = get_frontend_base_url()
|
||||
suffix = "&showAutopilot=1" if show_autopilot else ""
|
||||
return f"{base_url}/copilot?sessionId={session_id}{suffix}"
|
||||
|
||||
|
||||
def _build_callback_link(token_id: str) -> str:
|
||||
return f"{_get_frontend_base_url()}/copilot?callbackToken={token_id}"
|
||||
return f"{get_frontend_base_url()}/copilot?callbackToken={token_id}"
|
||||
|
||||
|
||||
def _get_completion_email_template_name(start_type: ChatSessionStartType) -> str:
|
||||
@@ -782,8 +743,12 @@ async def _send_completion_email(session: ChatSession) -> None:
|
||||
report = session.completion_report
|
||||
if report is None:
|
||||
raise ValueError("Missing completion report")
|
||||
user = await prisma.models.User.prisma().find_unique(where={"id": session.user_id})
|
||||
if user is None:
|
||||
try:
|
||||
user = await user_db().get_user_by_id(session.user_id)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"User {session.user_id} not found") from exc
|
||||
|
||||
if not user.email:
|
||||
raise ValueError(f"User {session.user_id} not found")
|
||||
|
||||
approval_cta = report.has_pending_approvals
|
||||
@@ -792,8 +757,8 @@ async def _send_completion_email(session: ChatSession) -> None:
|
||||
cta_url = _build_session_link(session.session_id, show_autopilot=True)
|
||||
cta_label = "Review in Copilot"
|
||||
else:
|
||||
token = await _create_callback_token(session)
|
||||
cta_url = _build_callback_link(token.id)
|
||||
token_id = await _create_callback_token(session)
|
||||
cta_url = _build_callback_link(token_id)
|
||||
cta_label = (
|
||||
"Try Copilot"
|
||||
if session.start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA
|
||||
@@ -815,20 +780,12 @@ async def _send_completion_email(session: ChatSession) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def send_nightly_copilot_emails() -> int:
|
||||
candidates = await prisma.models.ChatSession.prisma().find_many(
|
||||
where={
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": None,
|
||||
"notificationEmailSkippedAt": None,
|
||||
},
|
||||
order={"updatedAt": "asc"},
|
||||
take=200,
|
||||
)
|
||||
async def _send_nightly_copilot_emails() -> int:
|
||||
candidates = await chat_db().get_pending_notification_chat_sessions(limit=200)
|
||||
|
||||
processed_count = 0
|
||||
for candidate in candidates:
|
||||
session = await get_chat_session(candidate.id)
|
||||
session = await get_chat_session(candidate.session_id)
|
||||
if session is None or session.is_manual:
|
||||
continue
|
||||
|
||||
@@ -891,32 +848,31 @@ async def send_nightly_copilot_emails() -> int:
|
||||
return processed_count
|
||||
|
||||
|
||||
async def send_nightly_copilot_emails() -> int:
|
||||
return await _send_nightly_copilot_emails()
|
||||
|
||||
|
||||
async def consume_callback_token(
|
||||
token_id: str,
|
||||
user_id: str,
|
||||
) -> CallbackTokenConsumeResult:
|
||||
token = await prisma.models.ChatSessionCallbackToken.prisma().find_unique(
|
||||
where={"id": token_id}
|
||||
)
|
||||
if token is None or token.userId != user_id:
|
||||
token = await chat_db().get_chat_session_callback_token(token_id)
|
||||
if token is None or token.user_id != user_id:
|
||||
raise ValueError("Callback token not found")
|
||||
if token.expiresAt <= datetime.now(UTC):
|
||||
if token.expires_at <= datetime.now(UTC):
|
||||
raise ValueError("Callback token has expired")
|
||||
|
||||
if token.consumedSessionId:
|
||||
return CallbackTokenConsumeResult(session_id=token.consumedSessionId)
|
||||
if token.consumed_session_id:
|
||||
return CallbackTokenConsumeResult(session_id=token.consumed_session_id)
|
||||
|
||||
session = await create_chat_session(
|
||||
user_id,
|
||||
initial_messages=[
|
||||
ChatMessage(role="assistant", content=token.callbackSessionMessage)
|
||||
ChatMessage(role="assistant", content=token.callback_session_message)
|
||||
],
|
||||
)
|
||||
await prisma.models.ChatSessionCallbackToken.prisma().update(
|
||||
where={"id": token_id},
|
||||
data={
|
||||
"consumedAt": datetime.now(UTC),
|
||||
"consumedSessionId": session.session_id,
|
||||
},
|
||||
await chat_db().mark_chat_session_callback_token_consumed(
|
||||
token_id,
|
||||
session.session_id,
|
||||
)
|
||||
return CallbackTokenConsumeResult(session_id=session.session_id)
|
||||
|
||||
@@ -69,18 +69,34 @@ def test_crosses_local_midnight_supports_offset_timezones() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_crosses_local_midnight_only_triggers_once_across_dst_shift() -> None:
|
||||
assert _crosses_local_midnight(
|
||||
datetime(2026, 3, 8, 4, 30, tzinfo=UTC),
|
||||
datetime(2026, 3, 8, 5, 0, tzinfo=UTC),
|
||||
"America/New_York",
|
||||
) == date(2026, 3, 8)
|
||||
assert (
|
||||
_crosses_local_midnight(
|
||||
datetime(2026, 3, 8, 5, 0, tzinfo=UTC),
|
||||
datetime(2026, 3, 8, 5, 30, tzinfo=UTC),
|
||||
"America/New_York",
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recent_manual_session_context_strips_internal_content(
|
||||
mocker,
|
||||
) -> None:
|
||||
session = SimpleNamespace(
|
||||
id="sess-1",
|
||||
session_id="sess-1",
|
||||
title="Manual work",
|
||||
updatedAt=datetime(2026, 3, 14, 9, 0, tzinfo=UTC),
|
||||
updated_at=datetime(2026, 3, 14, 9, 0, tzinfo=UTC),
|
||||
)
|
||||
session_prisma = SimpleNamespace(find_many=AsyncMock(return_value=[session]))
|
||||
message_prisma = SimpleNamespace(
|
||||
find_many=AsyncMock(
|
||||
chat_store = SimpleNamespace(
|
||||
get_manual_chat_sessions_since=AsyncMock(return_value=[session]),
|
||||
get_chat_messages_since=AsyncMock(
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
role="user",
|
||||
@@ -95,10 +111,9 @@ async def test_get_recent_manual_session_context_strips_internal_content(
|
||||
content="Completed a useful task for the user.",
|
||||
),
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
mocker.patch("prisma.models.ChatSession.prisma", return_value=session_prisma)
|
||||
mocker.patch("prisma.models.ChatMessage.prisma", return_value=message_prisma)
|
||||
mocker.patch("backend.copilot.autopilot.chat_db", return_value=chat_store)
|
||||
|
||||
context = await _get_recent_manual_session_context(
|
||||
"user-1",
|
||||
@@ -335,12 +350,12 @@ async def test_dispatch_nightly_copilot_respects_cohort_priority(mocker) -> None
|
||||
)
|
||||
|
||||
invited = SimpleNamespace(
|
||||
authUserId="invite-user",
|
||||
auth_user_id="invite-user",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
createdAt=fixed_now - timedelta(hours=72),
|
||||
created_at=fixed_now - timedelta(hours=72),
|
||||
)
|
||||
user_prisma = SimpleNamespace(
|
||||
find_many=AsyncMock(
|
||||
user_store = SimpleNamespace(
|
||||
list_users=AsyncMock(
|
||||
return_value=[
|
||||
invite_user,
|
||||
nightly_user,
|
||||
@@ -349,9 +364,14 @@ async def test_dispatch_nightly_copilot_respects_cohort_priority(mocker) -> None
|
||||
]
|
||||
)
|
||||
)
|
||||
invite_prisma = SimpleNamespace(find_many=AsyncMock(return_value=[invited]))
|
||||
mocker.patch("prisma.models.User.prisma", return_value=user_prisma)
|
||||
mocker.patch("prisma.models.InvitedUser.prisma", return_value=invite_prisma)
|
||||
invited_user_store = SimpleNamespace(
|
||||
list_invited_users_for_auth_users=AsyncMock(return_value=[invited])
|
||||
)
|
||||
mocker.patch("backend.copilot.autopilot.user_db", return_value=user_store)
|
||||
mocker.patch(
|
||||
"backend.copilot.autopilot.invited_user_db",
|
||||
return_value=invited_user_store,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.autopilot.is_feature_enabled",
|
||||
new_callable=AsyncMock,
|
||||
@@ -409,10 +429,11 @@ async def test_send_nightly_copilot_emails_queues_repair_for_missing_report(
|
||||
mocker,
|
||||
) -> None:
|
||||
session = _build_autopilot_session()
|
||||
candidate = SimpleNamespace(id=session.session_id)
|
||||
|
||||
chat_session_prisma = SimpleNamespace(find_many=AsyncMock(return_value=[candidate]))
|
||||
mocker.patch("prisma.models.ChatSession.prisma", return_value=chat_session_prisma)
|
||||
candidate = SimpleNamespace(session_id=session.session_id)
|
||||
chat_store = SimpleNamespace(
|
||||
get_pending_notification_chat_sessions=AsyncMock(return_value=[candidate])
|
||||
)
|
||||
mocker.patch("backend.copilot.autopilot.chat_db", return_value=chat_store)
|
||||
mocker.patch(
|
||||
"backend.copilot.autopilot.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
@@ -454,10 +475,11 @@ async def test_send_nightly_copilot_emails_sends_and_marks_sent(mocker) -> None:
|
||||
pending_approval_graph_exec_id=None,
|
||||
saved_at=datetime.now(UTC),
|
||||
)
|
||||
candidate = SimpleNamespace(id=session.session_id)
|
||||
|
||||
chat_session_prisma = SimpleNamespace(find_many=AsyncMock(return_value=[candidate]))
|
||||
mocker.patch("prisma.models.ChatSession.prisma", return_value=chat_session_prisma)
|
||||
candidate = SimpleNamespace(session_id=session.session_id)
|
||||
chat_store = SimpleNamespace(
|
||||
get_pending_notification_chat_sessions=AsyncMock(return_value=[candidate])
|
||||
)
|
||||
mocker.patch("backend.copilot.autopilot.chat_db", return_value=chat_store)
|
||||
mocker.patch(
|
||||
"backend.copilot.autopilot.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
@@ -508,10 +530,11 @@ async def test_send_nightly_copilot_emails_skips_when_should_not_notify(mocker)
|
||||
pending_approval_graph_exec_id=None,
|
||||
saved_at=datetime.now(UTC),
|
||||
)
|
||||
candidate = SimpleNamespace(id=session.session_id)
|
||||
|
||||
chat_session_prisma = SimpleNamespace(find_many=AsyncMock(return_value=[candidate]))
|
||||
mocker.patch("prisma.models.ChatSession.prisma", return_value=chat_session_prisma)
|
||||
candidate = SimpleNamespace(session_id=session.session_id)
|
||||
chat_store = SimpleNamespace(
|
||||
get_pending_notification_chat_sessions=AsyncMock(return_value=[candidate])
|
||||
)
|
||||
mocker.patch("backend.copilot.autopilot.chat_db", return_value=chat_store)
|
||||
mocker.patch(
|
||||
"backend.copilot.autopilot.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
@@ -548,17 +571,17 @@ async def test_send_nightly_copilot_emails_skips_when_should_not_notify(mocker)
|
||||
async def test_consume_callback_token_reuses_existing_session(mocker) -> None:
|
||||
token = SimpleNamespace(
|
||||
id="token-1",
|
||||
userId="user-1",
|
||||
expiresAt=datetime.now(UTC) + timedelta(hours=1),
|
||||
consumedSessionId="sess-existing",
|
||||
user_id="user-1",
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
consumed_session_id="sess-existing",
|
||||
)
|
||||
token_prisma = SimpleNamespace(
|
||||
find_unique=AsyncMock(return_value=token),
|
||||
update=AsyncMock(),
|
||||
chat_store = SimpleNamespace(
|
||||
get_chat_session_callback_token=AsyncMock(return_value=token),
|
||||
mark_chat_session_callback_token_consumed=AsyncMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"prisma.models.ChatSessionCallbackToken.prisma",
|
||||
return_value=token_prisma,
|
||||
"backend.copilot.autopilot.chat_db",
|
||||
return_value=chat_store,
|
||||
)
|
||||
create_chat_session = mocker.patch(
|
||||
"backend.copilot.autopilot.create_chat_session",
|
||||
@@ -569,26 +592,26 @@ async def test_consume_callback_token_reuses_existing_session(mocker) -> None:
|
||||
|
||||
assert result.session_id == "sess-existing"
|
||||
create_chat_session.assert_not_called()
|
||||
token_prisma.update.assert_not_called()
|
||||
chat_store.mark_chat_session_callback_token_consumed.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consume_callback_token_creates_manual_session(mocker) -> None:
|
||||
token = SimpleNamespace(
|
||||
id="token-1",
|
||||
userId="user-1",
|
||||
expiresAt=datetime.now(UTC) + timedelta(hours=1),
|
||||
consumedSessionId=None,
|
||||
callbackSessionMessage="Open this chat",
|
||||
user_id="user-1",
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
consumed_session_id=None,
|
||||
callback_session_message="Open this chat",
|
||||
)
|
||||
token_prisma = SimpleNamespace(
|
||||
find_unique=AsyncMock(return_value=token),
|
||||
update=AsyncMock(),
|
||||
chat_store = SimpleNamespace(
|
||||
get_chat_session_callback_token=AsyncMock(return_value=token),
|
||||
mark_chat_session_callback_token_consumed=AsyncMock(),
|
||||
)
|
||||
created_session = ChatSession.new("user-1")
|
||||
mocker.patch(
|
||||
"prisma.models.ChatSessionCallbackToken.prisma",
|
||||
return_value=token_prisma,
|
||||
"backend.copilot.autopilot.chat_db",
|
||||
return_value=chat_store,
|
||||
)
|
||||
create_chat_session = mocker.patch(
|
||||
"backend.copilot.autopilot.create_chat_session",
|
||||
@@ -603,24 +626,27 @@ async def test_consume_callback_token_creates_manual_session(mocker) -> None:
|
||||
create_kwargs = create_chat_session.await_args.kwargs
|
||||
assert create_kwargs["initial_messages"][0].role == "assistant"
|
||||
assert create_kwargs["initial_messages"][0].content == "Open this chat"
|
||||
token_prisma.update.assert_awaited_once()
|
||||
chat_store.mark_chat_session_callback_token_consumed.assert_awaited_once_with(
|
||||
"token-1",
|
||||
created_session.session_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consume_callback_token_rejects_expired_token(mocker) -> None:
|
||||
token = SimpleNamespace(
|
||||
id="token-1",
|
||||
userId="user-1",
|
||||
expiresAt=datetime.now(UTC) - timedelta(minutes=1),
|
||||
consumedSessionId=None,
|
||||
callbackSessionMessage="Open this chat",
|
||||
user_id="user-1",
|
||||
expires_at=datetime.now(UTC) - timedelta(minutes=1),
|
||||
consumed_session_id=None,
|
||||
callback_session_message="Open this chat",
|
||||
)
|
||||
token_prisma = SimpleNamespace(
|
||||
find_unique=AsyncMock(return_value=token),
|
||||
chat_store = SimpleNamespace(
|
||||
get_chat_session_callback_token=AsyncMock(return_value=token),
|
||||
)
|
||||
mocker.patch(
|
||||
"prisma.models.ChatSessionCallbackToken.prisma",
|
||||
return_value=token_prisma,
|
||||
"backend.copilot.autopilot.chat_db",
|
||||
return_value=chat_store,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="expired"):
|
||||
|
||||
@@ -8,12 +8,14 @@ from typing import Any
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.models import ChatSessionCallbackToken as PrismaChatSessionCallbackToken
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
@@ -25,6 +27,31 @@ logger = logging.getLogger(__name__)
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class ChatSessionCallbackTokenInfo(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
source_session_id: str | None = None
|
||||
callback_session_message: str
|
||||
expires_at: datetime
|
||||
consumed_at: datetime | None = None
|
||||
consumed_session_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(
|
||||
cls,
|
||||
token: PrismaChatSessionCallbackToken,
|
||||
) -> "ChatSessionCallbackTokenInfo":
|
||||
return cls(
|
||||
id=token.id,
|
||||
user_id=token.userId,
|
||||
source_session_id=token.sourceSessionId,
|
||||
callback_session_message=token.callbackSessionMessage,
|
||||
expires_at=token.expiresAt,
|
||||
consumed_at=token.consumedAt,
|
||||
consumed_session_id=token.consumedSessionId,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
@@ -34,6 +61,67 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def has_recent_manual_message(user_id: str, since: datetime) -> bool:
|
||||
message = await PrismaChatMessage.prisma().find_first(
|
||||
where={
|
||||
"role": "user",
|
||||
"createdAt": {"gte": since},
|
||||
"Session": {
|
||||
"is": {
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
return message is not None
|
||||
|
||||
|
||||
async def has_session_since(user_id: str, since: datetime) -> bool:
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "createdAt": {"gte": since}}
|
||||
)
|
||||
return session is not None
|
||||
|
||||
|
||||
async def session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
|
||||
session = await PrismaChatSession.prisma().find_first(
|
||||
where={"userId": user_id, "executionTag": execution_tag}
|
||||
)
|
||||
return session is not None
|
||||
|
||||
|
||||
async def get_manual_chat_sessions_since(
|
||||
user_id: str,
|
||||
since_utc: datetime,
|
||||
limit: int,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"startType": ChatSessionStartType.MANUAL.value,
|
||||
"updatedAt": {"gte": since_utc},
|
||||
},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_chat_messages_since(
|
||||
session_id: str,
|
||||
since_utc: datetime,
|
||||
) -> list[ChatMessage]:
|
||||
messages = await PrismaChatMessage.prisma().find_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"createdAt": {"gte": since_utc},
|
||||
},
|
||||
order={"sequence": "asc"},
|
||||
)
|
||||
return [ChatMessage.from_db(message) for message in messages]
|
||||
|
||||
|
||||
def _build_chat_message_create_input(
|
||||
*,
|
||||
session_id: str,
|
||||
@@ -316,6 +404,21 @@ async def get_user_chat_sessions(
|
||||
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
|
||||
|
||||
|
||||
async def get_pending_notification_chat_sessions(
|
||||
limit: int = 200,
|
||||
) -> list[ChatSessionInfo]:
|
||||
sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={
|
||||
"startType": {"not": ChatSessionStartType.MANUAL.value},
|
||||
"notificationEmailSentAt": None,
|
||||
"notificationEmailSkippedAt": None,
|
||||
},
|
||||
order={"updatedAt": "asc"},
|
||||
take=limit,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(session) for session in sessions]
|
||||
|
||||
|
||||
async def get_user_session_count(
|
||||
user_id: str,
|
||||
with_auto: bool = False,
|
||||
@@ -416,3 +519,42 @@ async def update_tool_message_content(
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def create_chat_session_callback_token(
|
||||
user_id: str,
|
||||
source_session_id: str,
|
||||
callback_session_message: str,
|
||||
expires_at: datetime,
|
||||
) -> ChatSessionCallbackTokenInfo:
|
||||
token = await PrismaChatSessionCallbackToken.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"sourceSessionId": source_session_id,
|
||||
"callbackSessionMessage": callback_session_message,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
return ChatSessionCallbackTokenInfo.from_db(token)
|
||||
|
||||
|
||||
async def get_chat_session_callback_token(
|
||||
token_id: str,
|
||||
) -> ChatSessionCallbackTokenInfo | None:
|
||||
token = await PrismaChatSessionCallbackToken.prisma().find_unique(
|
||||
where={"id": token_id}
|
||||
)
|
||||
return ChatSessionCallbackTokenInfo.from_db(token) if token else None
|
||||
|
||||
|
||||
async def mark_chat_session_callback_token_consumed(
|
||||
token_id: str,
|
||||
consumed_session_id: str,
|
||||
) -> None:
|
||||
await PrismaChatSessionCallbackToken.prisma().update(
|
||||
where={"id": token_id},
|
||||
data={
|
||||
"consumedAt": datetime.now(UTC),
|
||||
"consumedSessionId": consumed_session_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -169,7 +169,6 @@ class ChatSession(ChatSessionInfo):
|
||||
def new(
|
||||
cls,
|
||||
user_id: str,
|
||||
*,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: ChatSessionConfig | None = None,
|
||||
@@ -678,7 +677,6 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
|
||||
async def create_chat_session(
|
||||
user_id: str,
|
||||
*,
|
||||
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
|
||||
execution_tag: str | None = None,
|
||||
session_config: ChatSessionConfig | None = None,
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from prisma.models import PendingHumanReview
|
||||
|
||||
from backend.copilot.constants import COPILOT_SESSION_PREFIX
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.session_types import CompletionReportInput
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import CompletionReportSavedResponse, ErrorResponse, ToolResponseBase
|
||||
@@ -64,11 +62,9 @@ class CompletionReportTool(BaseTool):
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
pending_approval_count = await PendingHumanReview.prisma().count(
|
||||
where={
|
||||
"graphExecId": f"{COPILOT_SESSION_PREFIX}{session.session_id}",
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
pending_approval_count = await review_db().count_pending_reviews_for_graph_exec(
|
||||
f"{COPILOT_SESSION_PREFIX}{session.session_id}",
|
||||
session.user_id,
|
||||
)
|
||||
|
||||
if pending_approval_count > 0 and not report.approval_summary:
|
||||
|
||||
@@ -39,11 +39,11 @@ async def test_completion_report_requires_approval_summary_when_pending(
|
||||
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
|
||||
)
|
||||
|
||||
pending_reviews = Mock()
|
||||
pending_reviews.count = AsyncMock(return_value=2)
|
||||
review_store = Mock()
|
||||
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=2)
|
||||
mocker.patch(
|
||||
"backend.copilot.tools.completion_report.PendingHumanReview.prisma",
|
||||
return_value=pending_reviews,
|
||||
"backend.copilot.tools.completion_report.review_db",
|
||||
return_value=review_store,
|
||||
)
|
||||
|
||||
response = await tool._execute(
|
||||
@@ -71,11 +71,11 @@ async def test_completion_report_succeeds_without_pending_approvals(
|
||||
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
|
||||
)
|
||||
|
||||
pending_reviews = Mock()
|
||||
pending_reviews.count = AsyncMock(return_value=0)
|
||||
review_store = Mock()
|
||||
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=0)
|
||||
mocker.patch(
|
||||
"backend.copilot.tools.completion_report.PendingHumanReview.prisma",
|
||||
return_value=pending_reviews,
|
||||
"backend.copilot.tools.completion_report.review_db",
|
||||
return_value=review_store,
|
||||
)
|
||||
|
||||
response = await tool._execute(
|
||||
|
||||
@@ -92,6 +92,19 @@ def user_db():
|
||||
return user_db
|
||||
|
||||
|
||||
def invited_user_db():
|
||||
if db.is_connected():
|
||||
from backend.data import invited_user as _invited_user_db
|
||||
|
||||
invited_user_db = _invited_user_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
invited_user_db = get_database_manager_async_client()
|
||||
|
||||
return invited_user_db
|
||||
|
||||
|
||||
def understanding_db():
|
||||
if db.is_connected():
|
||||
from backend.data import understanding as _understanding_db
|
||||
|
||||
@@ -79,6 +79,7 @@ from backend.data.graph import (
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
count_pending_reviews_for_graph_exec,
|
||||
delete_review_by_node_exec_id,
|
||||
get_or_create_human_review,
|
||||
get_pending_reviews_for_execution,
|
||||
@@ -86,6 +87,7 @@ from backend.data.human_review import (
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
)
|
||||
from backend.data.invited_user import list_invited_users_for_auth_users
|
||||
from backend.data.notifications import (
|
||||
clear_all_user_notification_batches,
|
||||
create_or_add_to_user_notification_batch,
|
||||
@@ -107,6 +109,7 @@ from backend.data.user import (
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
get_user_notification_preference,
|
||||
list_users,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.data.workspace import (
|
||||
@@ -115,6 +118,7 @@ from backend.data.workspace import (
|
||||
get_or_create_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_file_by_path,
|
||||
get_workspace_files_by_ids,
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
@@ -237,6 +241,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
list_users = _(list_users)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
@@ -249,6 +254,7 @@ class DatabaseManager(AppService):
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
count_pending_reviews_for_graph_exec = _(count_pending_reviews_for_graph_exec)
|
||||
delete_review_by_node_exec_id = _(delete_review_by_node_exec_id)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
get_pending_reviews_for_execution = _(get_pending_reviews_for_execution)
|
||||
@@ -313,12 +319,16 @@ class DatabaseManager(AppService):
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = _(count_workspace_files)
|
||||
create_workspace_file = _(create_workspace_file)
|
||||
get_workspace_files_by_ids = _(get_workspace_files_by_ids)
|
||||
get_or_create_workspace = _(get_or_create_workspace)
|
||||
get_workspace_file = _(get_workspace_file)
|
||||
get_workspace_file_by_path = _(get_workspace_file_by_path)
|
||||
list_workspace_files = _(list_workspace_files)
|
||||
soft_delete_workspace_file = _(soft_delete_workspace_file)
|
||||
|
||||
# ============ Invited Users ============ #
|
||||
list_invited_users_for_auth_users = _(list_invited_users_for_auth_users)
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
@@ -328,8 +338,21 @@ class DatabaseManager(AppService):
|
||||
update_block_optimized_description = _(update_block_optimized_description)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_messages_since = _(chat_db.get_chat_messages_since)
|
||||
get_chat_session_callback_token = _(chat_db.get_chat_session_callback_token)
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session_callback_token = _(chat_db.create_chat_session_callback_token)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
get_manual_chat_sessions_since = _(chat_db.get_manual_chat_sessions_since)
|
||||
get_pending_notification_chat_sessions = _(
|
||||
chat_db.get_pending_notification_chat_sessions
|
||||
)
|
||||
has_recent_manual_message = _(chat_db.has_recent_manual_message)
|
||||
has_session_since = _(chat_db.has_session_since)
|
||||
mark_chat_session_callback_token_consumed = _(
|
||||
chat_db.mark_chat_session_callback_token_consumed
|
||||
)
|
||||
session_exists_for_execution_tag = _(chat_db.session_exists_for_execution_tag)
|
||||
update_chat_session = _(chat_db.update_chat_session)
|
||||
add_chat_message = _(chat_db.add_chat_message)
|
||||
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
|
||||
@@ -374,10 +397,12 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_marketplace_graphs_for_monitoring = _(d.get_marketplace_graphs_for_monitoring)
|
||||
|
||||
# Human In The Loop
|
||||
count_pending_reviews_for_graph_exec = _(d.count_pending_reviews_for_graph_exec)
|
||||
has_pending_reviews_for_graph_exec = _(d.has_pending_reviews_for_graph_exec)
|
||||
|
||||
# User Emails
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
list_users = _(d.list_users)
|
||||
|
||||
# Library
|
||||
list_library_agents = _(d.list_library_agents)
|
||||
@@ -433,12 +458,14 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = d.get_user_by_id
|
||||
list_users = d.list_users
|
||||
get_user_integrations = d.get_user_integrations
|
||||
update_user_integrations = d.update_user_integrations
|
||||
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
count_pending_reviews_for_graph_exec = d.count_pending_reviews_for_graph_exec
|
||||
delete_review_by_node_exec_id = d.delete_review_by_node_exec_id
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
get_pending_reviews_for_execution = d.get_pending_reviews_for_execution
|
||||
@@ -506,12 +533,16 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# ============ Workspace ============ #
|
||||
count_workspace_files = d.count_workspace_files
|
||||
create_workspace_file = d.create_workspace_file
|
||||
get_workspace_files_by_ids = d.get_workspace_files_by_ids
|
||||
get_or_create_workspace = d.get_or_create_workspace
|
||||
get_workspace_file = d.get_workspace_file
|
||||
get_workspace_file_by_path = d.get_workspace_file_by_path
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Invited Users ============ #
|
||||
list_invited_users_for_auth_users = d.list_invited_users_for_auth_users
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
@@ -520,8 +551,19 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_blocks_needing_optimization = d.get_blocks_needing_optimization
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_messages_since = d.get_chat_messages_since
|
||||
get_chat_session_callback_token = d.get_chat_session_callback_token
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session_callback_token = d.create_chat_session_callback_token
|
||||
create_chat_session = d.create_chat_session
|
||||
get_manual_chat_sessions_since = d.get_manual_chat_sessions_since
|
||||
get_pending_notification_chat_sessions = d.get_pending_notification_chat_sessions
|
||||
has_recent_manual_message = d.has_recent_manual_message
|
||||
has_session_since = d.has_session_since
|
||||
mark_chat_session_callback_token_consumed = (
|
||||
d.mark_chat_session_callback_token_consumed
|
||||
)
|
||||
session_exists_for_execution_tag = d.session_exists_for_execution_tag
|
||||
update_chat_session = d.update_chat_session
|
||||
add_chat_message = d.add_chat_message
|
||||
add_chat_messages_batch = d.add_chat_messages_batch
|
||||
|
||||
@@ -342,6 +342,19 @@ async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
return count > 0
|
||||
|
||||
|
||||
async def count_pending_reviews_for_graph_exec(
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
) -> int:
|
||||
return await PendingHumanReview.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_id(node_exec_id: str, get_node_execution) -> str:
|
||||
"""Resolve node_id from a node_exec_id.
|
||||
|
||||
|
||||
@@ -215,6 +215,18 @@ async def list_invited_users(
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def list_invited_users_for_auth_users(
|
||||
auth_user_ids: list[str],
|
||||
) -> list[InvitedUserRecord]:
|
||||
if not auth_user_ids:
|
||||
return []
|
||||
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
where={"authUserId": {"in": auth_user_ids}}
|
||||
)
|
||||
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Data models and access layer for user business understanding."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
@@ -35,7 +33,7 @@ def _json_to_list(value: Any) -> list[str]:
|
||||
|
||||
def parse_business_understanding_input(
|
||||
payload: Any,
|
||||
) -> BusinessUnderstandingInput | None:
|
||||
) -> "BusinessUnderstandingInput | None":
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -62,6 +62,14 @@ async def get_user_by_id(user_id: str) -> User:
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def list_users() -> list[User]:
|
||||
try:
|
||||
users = await PrismaUser.prisma().find_many()
|
||||
return [User.from_db(user) for user in users]
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to list users: {e}") from e
|
||||
|
||||
|
||||
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
|
||||
@@ -254,6 +254,23 @@ async def list_workspace_files(
|
||||
return [WorkspaceFile.from_db(f) for f in files]
|
||||
|
||||
|
||||
async def get_workspace_files_by_ids(
|
||||
workspace_id: str,
|
||||
file_ids: list[str],
|
||||
) -> list[WorkspaceFile]:
|
||||
if not file_ids:
|
||||
return []
|
||||
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": file_ids},
|
||||
"workspaceId": workspace_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
return [WorkspaceFile.from_db(file) for file in files]
|
||||
|
||||
|
||||
async def count_workspace_files(
|
||||
workspace_id: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.data.notifications import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.url import get_frontend_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -48,6 +49,25 @@ class EmailSender:
|
||||
|
||||
MAX_EMAIL_CHARS = 5_000_000 # ~5MB buffer
|
||||
|
||||
def _get_unsubscribe_link(self, user_unsubscribe_link: str | None) -> str:
|
||||
return user_unsubscribe_link or f"{get_frontend_base_url()}/profile/settings"
|
||||
|
||||
def _format_template_email(
|
||||
self,
|
||||
*,
|
||||
subject_template: str,
|
||||
content_template: str,
|
||||
data: Any,
|
||||
unsubscribe_link: str,
|
||||
) -> tuple[str, str]:
|
||||
return self.formatter.format_email(
|
||||
base_template=self._read_template("templates/base.html.jinja2"),
|
||||
subject_template=subject_template,
|
||||
content_template=content_template,
|
||||
data=data,
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def _build_large_output_summary(
|
||||
self,
|
||||
data: (
|
||||
@@ -102,14 +122,10 @@ class EmailSender:
|
||||
logger.warning("Postmark client not initialized, email not sent")
|
||||
return
|
||||
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
unsubscribe_link = user_unsubscribe_link or f"{base_url}/profile/settings"
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
|
||||
|
||||
_, full_message = self.formatter.format_email(
|
||||
_, full_message = self._format_template_email(
|
||||
subject_template="{{ subject }}",
|
||||
base_template=self._read_template("templates/base.html.jinja2"),
|
||||
content_template=self._read_template(f"templates/{template_name}"),
|
||||
data={"subject": subject, **(data or {})},
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
@@ -119,7 +135,7 @@ class EmailSender:
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=full_message,
|
||||
user_unsubscribe_link=user_unsubscribe_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def send_templated(
|
||||
@@ -138,21 +154,18 @@ class EmailSender:
|
||||
return
|
||||
|
||||
template = self._get_template(notification)
|
||||
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
base_url = get_frontend_base_url()
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsub_link)
|
||||
|
||||
# Normalize data
|
||||
template_data = {"notifications": data} if isinstance(data, list) else data
|
||||
|
||||
try:
|
||||
subject, full_message = self.formatter.format_email(
|
||||
base_template=template.base_template,
|
||||
subject, full_message = self._format_template_email(
|
||||
subject_template=template.subject_template,
|
||||
content_template=template.body_template,
|
||||
data=template_data,
|
||||
unsubscribe_link=f"{base_url}/profile/settings",
|
||||
unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting full message: {e}")
|
||||
@@ -176,7 +189,7 @@ class EmailSender:
|
||||
user_email=user_email,
|
||||
subject=f"{subject} (Output Too Large)",
|
||||
body=summary_message,
|
||||
user_unsubscribe_link=user_unsub_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
return # Skip sending full email
|
||||
|
||||
@@ -185,7 +198,7 @@ class EmailSender:
|
||||
user_email=user_email,
|
||||
subject=subject,
|
||||
body=full_message,
|
||||
user_unsubscribe_link=user_unsub_link,
|
||||
user_unsubscribe_link=unsubscribe_link,
|
||||
)
|
||||
|
||||
def _get_template(self, notification: NotificationType):
|
||||
@@ -218,20 +231,17 @@ class EmailSender:
|
||||
if not self.postmark:
|
||||
logger.warning("Email tried to send without postmark configured")
|
||||
return
|
||||
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
|
||||
logger.debug(f"Sending email to {user_email} with subject {subject}")
|
||||
self.postmark.emails.send(
|
||||
From=settings.config.postmark_sender_email,
|
||||
To=user_email,
|
||||
Subject=subject,
|
||||
HtmlBody=body,
|
||||
Headers=(
|
||||
{
|
||||
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
|
||||
"List-Unsubscribe": f"<{user_unsubscribe_link}>",
|
||||
}
|
||||
if user_unsubscribe_link
|
||||
else None
|
||||
),
|
||||
Headers={
|
||||
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
|
||||
"List-Unsubscribe": f"<{unsubscribe_link}>",
|
||||
},
|
||||
)
|
||||
|
||||
def send_html(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.api.test_helpers import override_config
|
||||
@@ -29,6 +30,9 @@ def test_send_template_renders_nightly_copilot_email(mocker) -> None:
|
||||
assert "I found something useful for you." in body
|
||||
assert "Open Copilot" in body
|
||||
assert "Approval needed" not in body
|
||||
assert send_email.call_args.kwargs["user_unsubscribe_link"].endswith(
|
||||
"/profile/settings"
|
||||
)
|
||||
|
||||
|
||||
def test_send_template_renders_nightly_copilot_approval_block(mocker) -> None:
|
||||
@@ -180,3 +184,21 @@ def test_send_template_still_sends_in_production(mocker) -> None:
|
||||
)
|
||||
|
||||
send_email.assert_called_once()
|
||||
|
||||
|
||||
def test_send_html_uses_default_unsubscribe_link(mocker) -> None:
|
||||
sender = EmailSender()
|
||||
send = mocker.Mock()
|
||||
sender.postmark = cast(Any, SimpleNamespace(emails=SimpleNamespace(send=send)))
|
||||
|
||||
with override_config(settings, "frontend_base_url", "https://example.com"):
|
||||
sender.send_html(
|
||||
user_email="user@example.com",
|
||||
subject="Autopilot update",
|
||||
body="<p>Hello</p>",
|
||||
)
|
||||
|
||||
headers = send.call_args.kwargs["Headers"]
|
||||
|
||||
assert headers["List-Unsubscribe-Post"] == "List-Unsubscribe=One-Click"
|
||||
assert headers["List-Unsubscribe"] == "<https://example.com/profile/settings>"
|
||||
|
||||
9
autogpt_platform/backend/backend/util/url.py
Normal file
9
autogpt_platform/backend/backend/util/url.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_frontend_base_url() -> str:
|
||||
return (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
).rstrip("/")
|
||||
Reference in New Issue
Block a user