mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
revert(platform): Revert invite system (#12485)
## Summary Reverts the invite system PRs due to security gaps identified during review: - The move from Supabase-native `allowed_users` gating to application-level gating allows orphaned Supabase auth accounts (valid JWT without a platform `User`) - The auth middleware never verifies `User` existence, so orphaned users get 500s instead of clean 403s - OAuth/Google SSO signup completely bypasses the invite gate - The DB trigger that atomically created `User` + `Profile` on signup was dropped in favor of a client-initiated API call, introducing a failure window ### Reverted PRs - Reverts #12347 — Foundation: InvitedUser model, invite-gated signup, admin UI - Reverts #12374 — Tally enrichment: personalized prompts from form submissions - Reverts #12451 — Pre-check: POST /auth/check-invite endpoint - Reverts #12452 (collateral) — Themed prompt categories / SuggestionThemes UI. This PR built on top of #12374's `suggested_prompts` backend field and `/chat/suggested-prompts` endpoint, so it cannot remain without #12374. The copilot empty session falls back to hardcoded default prompts. ### Migration Includes a new migration (`20260319120000_revert_invite_system`) that: - Drops the `InvitedUser` table and its enums (`InvitedUserStatus`, `TallyComputationStatus`) - Restores the `add_user_and_profile_to_platform()` trigger on `auth.users` - Backfills `User` + `Profile` rows for any auth accounts created during the invite-gate window ### What's NOT reverted - The `generate_username()` function (never dropped, still used by backfill migration) - The old `add_user_to_platform()` function (superseded by `add_user_and_profile_to_platform()`) - PR #12471 (admin UX improvements) — was never merged, no action needed ## Test plan - [x] Verify migration: `InvitedUser` table dropped, enums dropped, trigger restored - [x] Verify backfill: no orphaned auth users, no users without Profile - [x] Verify existing users can still log in (email + OAuth) - [x] Verify CoPilot chat page loads with default prompts - [ ] Verify new user signup creates `User` + `Profile` via the restored trigger - [ ] Verify admin `/admin/users` page loads without crashing - [ ] Run backend tests: `poetry run test` 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
This commit is contained in:
committed by
GitHub
parent
0ce1c90b55
commit
5b9a4c52c9
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -60,7 +60,6 @@ from backend.copilot.tools.models import (
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -895,47 +894,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedTheme(BaseModel):
|
||||
"""A themed group of suggested prompts."""
|
||||
|
||||
name: str
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts grouped by theme."""
|
||||
|
||||
themes: list[SuggestedTheme]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts grouped by theme.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns empty themes list if no custom
|
||||
prompts are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None or not understanding.suggested_prompts:
|
||||
return SuggestedPromptsResponse(themes=[])
|
||||
|
||||
themes = [
|
||||
SuggestedTheme(name=name, prompts=prompts)
|
||||
for name, prompts in understanding.suggested_prompts.items()
|
||||
]
|
||||
return SuggestedPromptsResponse(themes=themes)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -400,69 +400,3 @@ def test_usage_rejects_unauthenticated_request() -> None:
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_themes(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with themed prompts gets them back as themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {
|
||||
"Learn": ["L1", "L2"],
|
||||
"Create": ["C1"],
|
||||
}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "themes" in data
|
||||
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
|
||||
assert themes_by_name["Learn"] == ["L1", "L2"]
|
||||
assert themes_by_name["Create"] == ["C1"]
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty themes list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but empty prompts gets empty themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
@@ -24,7 +24,7 @@ from fastapi import (
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -55,11 +55,6 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import (
|
||||
check_invite_eligibility,
|
||||
get_or_activate_user,
|
||||
is_internal_email,
|
||||
)
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -74,8 +69,8 @@ from backend.data.onboarding import (
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -134,69 +129,6 @@ v1_router = APIRouter()
|
||||
_tally_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
class CheckInviteRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class CheckInviteResponse(BaseModel):
|
||||
allowed: bool
|
||||
|
||||
|
||||
_CHECK_INVITE_RATE_LIMIT = 10 # requests
|
||||
_CHECK_INVITE_RATE_WINDOW = 60 # seconds
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/check-invite",
|
||||
summary="Check if an email is allowed to sign up",
|
||||
tags=["auth"],
|
||||
)
|
||||
async def check_invite_route(
|
||||
http_request: Request,
|
||||
request: CheckInviteRequest,
|
||||
) -> CheckInviteResponse:
|
||||
"""Check if an email is allowed to sign up (no auth required).
|
||||
|
||||
Called by the frontend before creating a Supabase auth user to prevent
|
||||
orphaned accounts when the invite gate is enabled.
|
||||
"""
|
||||
client_ip = (
|
||||
http_request.headers.get("x-forwarded-for", "").split(",")[0].strip()
|
||||
or http_request.headers.get("x-real-ip", "")
|
||||
or (http_request.client.host if http_request.client else "unknown")
|
||||
)
|
||||
rate_key = f"rate:check-invite:{client_ip}"
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
# Use a pipeline so that incr + expire are sent atomically.
|
||||
# This prevents the key from persisting indefinitely when expire fails
|
||||
# after a successful incr (which would permanently block the IP once
|
||||
# the count exceeds the limit).
|
||||
# NOTE: pipeline command methods (incr, expire) are NOT awaitable —
|
||||
# they queue the command and return the pipeline. Only execute() is
|
||||
# awaited, which flushes all queued commands in a single round-trip.
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, _CHECK_INVITE_RATE_WINDOW)
|
||||
results = await pipe.execute()
|
||||
count = results[0]
|
||||
if count > _CHECK_INVITE_RATE_LIMIT:
|
||||
raise HTTPException(status_code=429, detail="Too many requests")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.debug("Rate limit check failed for check-invite, failing open")
|
||||
|
||||
if not settings.config.enable_invite_gate:
|
||||
return CheckInviteResponse(allowed=True)
|
||||
|
||||
if is_internal_email(request.email):
|
||||
return CheckInviteResponse(allowed=True)
|
||||
|
||||
allowed = await check_invite_eligibility(request.email)
|
||||
return CheckInviteResponse(allowed=allowed)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
@@ -204,10 +136,12 @@ async def check_invite_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -231,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -246,16 +179,10 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
user = await get_or_create_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -266,8 +193,7 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
|
||||
@@ -35,102 +35,6 @@ def setup_app_auth(mock_jwt_user, setup_test_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# check_invite_route tests
|
||||
|
||||
_RATE_LIMIT_PATCH = "backend.api.features.v1.get_redis_async"
|
||||
|
||||
|
||||
def _make_redis_mock(count: int = 1) -> AsyncMock:
|
||||
"""Return a mock Redis client that reports `count` for the rate-limit key.
|
||||
|
||||
The route uses a pipeline where incr/expire are synchronous (they queue
|
||||
commands and return the pipeline) and only execute() is awaited.
|
||||
"""
|
||||
mock_pipe = Mock()
|
||||
mock_pipe.incr = Mock(return_value=mock_pipe)
|
||||
mock_pipe.expire = Mock(return_value=mock_pipe)
|
||||
mock_pipe.execute = AsyncMock(return_value=[count, True])
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = Mock(return_value=mock_pipe)
|
||||
return mock_redis
|
||||
|
||||
|
||||
def test_check_invite_gate_disabled(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""When enable_invite_gate is False every email is allowed."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=False)),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "anyone@example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": True}
|
||||
|
||||
|
||||
def test_check_invite_internal_email_bypasses_gate(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""@agpt.co addresses bypass the gate even when it is enabled."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=True)),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "employee@agpt.co"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": True}
|
||||
|
||||
|
||||
def test_check_invite_eligible_email(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""An email with INVITED status is allowed when the gate is enabled."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=True)),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.check_invite_eligibility",
|
||||
new=AsyncMock(return_value=True),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "invited@example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": True}
|
||||
|
||||
|
||||
def test_check_invite_ineligible_email(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""An email without an active invite is denied when the gate is enabled."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings",
|
||||
Mock(config=Mock(enable_invite_gate=True)),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.check_invite_eligibility",
|
||||
new=AsyncMock(return_value=False),
|
||||
)
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "stranger@example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"allowed": False}
|
||||
|
||||
|
||||
def test_check_invite_rate_limit_exceeded(mocker: pytest_mock.MockFixture) -> None:
|
||||
"""Requests beyond the per-IP rate limit receive HTTP 429."""
|
||||
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock(count=11))
|
||||
|
||||
response = client.post("/auth/check-invite", json={"email": "flood@example.com"})
|
||||
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
# Auth endpoints tests
|
||||
def test_get_or_create_user_route(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
@@ -147,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -312,11 +311,6 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -1,774 +0,0 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from prisma.errors import UniqueViolationError
|
||||
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.model import User
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_settings = Settings()
|
||||
|
||||
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
|
||||
|
||||
_tally_seed_tasks: dict[str, asyncio.Task] = {}
|
||||
_TALLY_STALE_SECONDS = 300
|
||||
_MAX_TALLY_ERROR_LENGTH = 200
|
||||
_email_adapter = TypeAdapter(EmailStr)
|
||||
|
||||
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
||||
MAX_BULK_INVITE_ROWS = 500
|
||||
|
||||
|
||||
class InvitedUserRecord(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
created_at=invited_user.createdAt,
|
||||
updated_at=invited_user.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class BulkInvitedUserRowResult(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserRecord] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResult(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResult]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ParsedInviteRow:
|
||||
row_number: int
|
||||
email: str
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def is_internal_email(email: str) -> bool:
|
||||
"""Return True for @agpt.co addresses, which always bypass the invite gate."""
|
||||
return normalize_email(email).endswith("@agpt.co")
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
||||
if preferred_name:
|
||||
return preferred_name
|
||||
local_part = email.split("@", 1)[0].strip()
|
||||
return local_part or "user"
|
||||
|
||||
|
||||
def _sanitize_username_base(email: str) -> str:
|
||||
local_part = email.split("@", 1)[0].lower()
|
||||
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
||||
sanitized = sanitized.strip("-")
|
||||
return sanitized[:40] or "user"
|
||||
|
||||
|
||||
async def _generate_unique_profile_username(email: str, tx) -> str:
|
||||
base = _sanitize_username_base(email)
|
||||
|
||||
for _ in range(2):
|
||||
candidate = f"{base}-{uuid4().hex[:6]}"
|
||||
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"username": candidate}
|
||||
)
|
||||
if existing is None:
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(f"Unable to generate unique username for {email}")
|
||||
|
||||
|
||||
async def _ensure_default_profile(
|
||||
user_id: str,
|
||||
email: str,
|
||||
preferred_name: Optional[str],
|
||||
tx,
|
||||
) -> None:
|
||||
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing_profile is not None:
|
||||
return
|
||||
|
||||
username = await _generate_unique_profile_username(email, tx)
|
||||
await prisma.models.Profile.prisma(tx).create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=user_id,
|
||||
name=_default_profile_name(email, preferred_name),
|
||||
username=username,
|
||||
description="I'm new here",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
||||
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _apply_tally_understanding(
|
||||
user_id: str,
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Malformed tallyUnderstanding for invited user %s; skipping",
|
||||
invited_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(payload)},
|
||||
"update": {"data": SafeJson(payload)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def check_invite_eligibility(email: str) -> bool:
|
||||
"""Check if an email is allowed to sign up based on the invite list.
|
||||
|
||||
Args:
|
||||
email: The email to check (will be normalized internally).
|
||||
|
||||
Returns True if the email has an active (INVITED) invite record.
|
||||
Does NOT check enable_invite_gate — the caller is responsible for that.
|
||||
"""
|
||||
email = normalize_email(email)
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": email}
|
||||
)
|
||||
return (
|
||||
invited_user is not None
|
||||
and invited_user.status == prisma.enums.InvitedUserStatus.INVITED
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users(
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[InvitedUserRecord], int]:
|
||||
total = await prisma.models.InvitedUser.prisma().count()
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
order={"createdAt": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
normalized_email = normalize_email(email)
|
||||
normalized_name = _normalize_name(name)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_user is not None:
|
||||
raise PreconditionFailed("An active user with this email already exists")
|
||||
|
||||
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_invited_user is not None:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
|
||||
try:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().create(
|
||||
data={
|
||||
"email": normalized_email,
|
||||
"name": normalized_name,
|
||||
"status": prisma.enums.InvitedUserStatus.INVITED,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
schedule_invited_user_tally_precompute(invited_user.id)
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
|
||||
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
||||
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
||||
)
|
||||
if revoked_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
return InvitedUserRecord.from_db(revoked_user)
|
||||
|
||||
|
||||
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
||||
|
||||
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": None,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
"tallyComputedAt": None,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
if refreshed_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
schedule_invited_user_tally_precompute(invited_user_id)
|
||||
return InvitedUserRecord.from_db(refreshed_user)
|
||||
|
||||
|
||||
def _decode_bulk_invite_file(content: bytes) -> str:
|
||||
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
||||
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
||||
|
||||
try:
|
||||
return content.decode("utf-8-sig")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
||||
|
||||
|
||||
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
||||
indexed_rows: list[tuple[int, list[str]]] = []
|
||||
|
||||
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
||||
normalized_row = [cell.strip() for cell in row]
|
||||
if any(normalized_row):
|
||||
indexed_rows.append((row_number, normalized_row))
|
||||
|
||||
if not indexed_rows:
|
||||
return []
|
||||
|
||||
header = [cell.lower() for cell in indexed_rows[0][1]]
|
||||
has_header = "email" in header
|
||||
email_index = header.index("email") if has_header else 0
|
||||
name_index: Optional[int] = (
|
||||
header.index("name")
|
||||
if has_header and "name" in header
|
||||
else (1 if not has_header else None)
|
||||
)
|
||||
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
||||
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
for row_number, row in data_rows:
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
email = row[email_index].strip() if len(row) > email_index else ""
|
||||
name = (
|
||||
row[name_index].strip()
|
||||
if name_index is not None and len(row) > name_index
|
||||
else ""
|
||||
)
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=email,
|
||||
name=name or None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
|
||||
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
||||
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
|
||||
break
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=line,
|
||||
name=None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> list[_ParsedInviteRow]:
|
||||
text = _decode_bulk_invite_file(content)
|
||||
file_name = filename.lower() if filename else ""
|
||||
parsed_rows = (
|
||||
_parse_bulk_invite_csv(text)
|
||||
if file_name.endswith(".csv")
|
||||
else _parse_bulk_invite_text(text)
|
||||
)
|
||||
|
||||
if not parsed_rows:
|
||||
raise ValueError("Invite file did not contain any emails")
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
async def bulk_create_invited_users_from_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> BulkInvitedUsersResult:
|
||||
parsed_rows = _parse_bulk_invite_file(filename, content)
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
results: list[BulkInvitedUserRowResult] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for row in parsed_rows:
|
||||
row_name = _normalize_name(row.name)
|
||||
|
||||
try:
|
||||
validated_email = _email_adapter.validate_python(row.email)
|
||||
except ValidationError:
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=row.email or None,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Invalid email address",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_email = normalize_email(str(validated_email))
|
||||
if normalized_email in seen_emails:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message="Duplicate email in upload file",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
seen_emails.add(normalized_email)
|
||||
|
||||
try:
|
||||
invited_user = await create_invited_user(normalized_email, row_name)
|
||||
except PreconditionFailed as exc:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
masked = mask_email(normalized_email)
|
||||
logger.exception(
|
||||
"Failed to create bulk invite for row %s (%s)",
|
||||
row.row_number,
|
||||
masked,
|
||||
)
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Unexpected error creating invite",
|
||||
)
|
||||
)
|
||||
else:
|
||||
created_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=invited_user,
|
||||
)
|
||||
)
|
||||
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=created_count,
|
||||
skipped_count=skipped_count,
|
||||
error_count=error_count,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
return
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return
|
||||
|
||||
try:
|
||||
r = await get_redis_async()
|
||||
except Exception:
|
||||
r = None
|
||||
|
||||
lock: AsyncClusterLock | None = None
|
||||
|
||||
if r is not None:
|
||||
lock = AsyncClusterLock(
|
||||
redis=r,
|
||||
key=f"tally_seed:{invited_user_id}",
|
||||
owner_id=_WORKER_ID,
|
||||
timeout=_TALLY_STALE_SECONDS,
|
||||
)
|
||||
current_owner = await lock.try_acquire()
|
||||
|
||||
if current_owner is None:
|
||||
logger.warn("Redis unvailable for tally lock - skipping tally enrichement")
|
||||
return
|
||||
elif current_owner != _WORKER_ID:
|
||||
logger.debug(
|
||||
"Tally seed for %s already locked by %s, skipping",
|
||||
invited_user_id,
|
||||
current_owner,
|
||||
)
|
||||
return
|
||||
if (
|
||||
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
|
||||
and invited_user.updatedAt is not None
|
||||
):
|
||||
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
|
||||
if age < _TALLY_STALE_SECONDS:
|
||||
logger.debug(
|
||||
"Tally task for %s still RUNNING (age=%ds), skipping",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Tally task for %s is stale (age=%ds), re-running",
|
||||
invited_user_id,
|
||||
int(age),
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
input_data = await get_business_understanding_input_from_tally(
|
||||
invited_user.email,
|
||||
require_api_key=True,
|
||||
)
|
||||
payload = (
|
||||
SafeJson(input_data.model_dump(exclude_none=True))
|
||||
if input_data is not None
|
||||
else None
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": payload,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
||||
"tallyComputedAt": datetime.now(timezone.utc),
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to compute Tally understanding for invited user %s",
|
||||
invited_user_id,
|
||||
)
|
||||
sanitized_error = re.sub(
|
||||
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
|
||||
)[:_MAX_TALLY_ERROR_LENGTH]
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
||||
"tallyError": sanitized_error,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
||||
existing = _tally_seed_tasks.get(invited_user_id)
|
||||
if existing is not None and not existing.done():
|
||||
logger.debug("Tally task already running for %s, skipping", invited_user_id)
|
||||
return
|
||||
|
||||
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
||||
_tally_seed_tasks[invited_user_id] = task
|
||||
|
||||
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
|
||||
if _tally_seed_tasks.get(_id) is t:
|
||||
del _tally_seed_tasks[_id]
|
||||
|
||||
task.add_done_callback(_on_done)
|
||||
|
||||
|
||||
async def _open_signup_create_user(
|
||||
auth_user_id: str,
|
||||
normalized_email: str,
|
||||
metadata_name: Optional[str],
|
||||
) -> User:
|
||||
"""Create a user without requiring an invite (open signup mode)."""
|
||||
preferred_name = _normalize_name(metadata_name)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
user = await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
await _ensure_default_profile(
|
||||
auth_user_id, normalized_email, preferred_name, tx
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
except UniqueViolationError:
|
||||
existing = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if existing is not None:
|
||||
return User.from_db(existing)
|
||||
raise
|
||||
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
# TODO: We need to change this functions logic before going live
|
||||
async def get_or_activate_user(user_data: dict) -> User:
|
||||
auth_user_id = user_data.get("sub")
|
||||
if not auth_user_id:
|
||||
raise NotAuthorizedError("User ID not found in token")
|
||||
|
||||
auth_email = user_data.get("email")
|
||||
if not auth_email:
|
||||
raise NotAuthorizedError("Email not found in token")
|
||||
|
||||
normalized_email = normalize_email(auth_email)
|
||||
user_metadata = user_data.get("user_metadata")
|
||||
metadata_name = (
|
||||
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
||||
)
|
||||
|
||||
existing_user = None
|
||||
try:
|
||||
existing_user = await get_user_by_id(auth_user_id)
|
||||
except ValueError:
|
||||
existing_user = None
|
||||
except Exception:
|
||||
logger.exception("Error on get user by id during tally enrichment process")
|
||||
raise
|
||||
|
||||
if existing_user is not None:
|
||||
return existing_user
|
||||
|
||||
if not _settings.config.enable_invite_gate or is_internal_email(normalized_email):
|
||||
return await _open_signup_create_user(
|
||||
auth_user_id, normalized_email, metadata_name
|
||||
)
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
current_user = await prisma.models.User.prisma(tx).find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if current_user is not None:
|
||||
return User.from_db(current_user)
|
||||
|
||||
current_invited_user = await prisma.models.InvitedUser.prisma(
|
||||
tx
|
||||
).find_unique(where={"email": normalized_email})
|
||||
if current_invited_user is None:
|
||||
raise NotAuthorizedError(
|
||||
"Your email is not allowed to access the platform"
|
||||
)
|
||||
|
||||
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
if current_invited_user.authUserId not in (None, auth_user_id):
|
||||
raise NotAuthorizedError("Your invitation has already been claimed")
|
||||
|
||||
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
||||
await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma(tx).update(
|
||||
where={"id": current_invited_user.id},
|
||||
data={
|
||||
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
||||
"authUserId": auth_user_id,
|
||||
},
|
||||
)
|
||||
|
||||
await _ensure_default_profile(
|
||||
auth_user_id,
|
||||
normalized_email,
|
||||
preferred_name,
|
||||
tx,
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
||||
except UniqueViolationError:
|
||||
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
|
||||
already_created = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if already_created is not None:
|
||||
return User.from_db(already_created)
|
||||
raise RuntimeError(
|
||||
f"UniqueViolationError during activation but user {auth_user_id} not found"
|
||||
)
|
||||
|
||||
get_user_by_id.cache_delete(auth_user_id)
|
||||
get_user_by_email.cache_delete(normalized_email)
|
||||
|
||||
activated_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if activated_user is None:
|
||||
raise RuntimeError(
|
||||
f"Activated user {auth_user_id} was not found after creation"
|
||||
)
|
||||
|
||||
return User.from_db(activated_user)
|
||||
@@ -1,409 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
|
||||
|
||||
from .invited_user import (
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
check_invite_eligibility,
|
||||
create_invited_user,
|
||||
get_or_activate_user,
|
||||
retry_invited_user_tally,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_db_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=status,
|
||||
authUserId=None,
|
||||
name="Invited User",
|
||||
tallyUnderstanding=tally_understanding,
|
||||
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tallyComputedAt=None,
|
||||
tallyError=None,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
|
||||
|
||||
def _invited_user_record(
|
||||
*,
|
||||
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
|
||||
tally_understanding: dict | None = None,
|
||||
):
|
||||
return InvitedUserRecord.from_db(
|
||||
cast(
|
||||
prisma.models.InvitedUser,
|
||||
_invited_user_db_record(
|
||||
status=status,
|
||||
tally_understanding=tally_understanding,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _user_db_record():
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
id="auth-user-1",
|
||||
email="invited@example.com",
|
||||
emailVerified=True,
|
||||
name="Invited User",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
metadata={},
|
||||
integrations="",
|
||||
stripeCustomerId=None,
|
||||
topUpConfig=None,
|
||||
maxEmailsPerDay=3,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="not-set",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_rejects_existing_active_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(PreconditionFailed):
|
||||
await create_invited_user("Invited@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invited_user_schedules_tally_seed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
user_repo = Mock()
|
||||
user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await create_invited_user("Invited@example.com", "Invited User")
|
||||
|
||||
assert invited_user.email == "invited@example.com"
|
||||
invited_user_repo.create.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_invited_user_tally_resets_state_and_schedules(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
|
||||
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
|
||||
schedule = mocker.patch(
|
||||
"backend.data.invited_user.schedule_invited_user_tally_precompute"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
invited_user = await retry_invited_user_tally("invite-1")
|
||||
|
||||
assert invited_user.id == "invite-1"
|
||||
invited_user_repo.update.assert_awaited_once()
|
||||
schedule.assert_called_once_with("invite-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_requires_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited_user_repo = Mock()
|
||||
invited_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user._settings.config.enable_invite_gate",
|
||||
True,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=invited_user_repo,
|
||||
)
|
||||
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await get_or_activate_user(
|
||||
{"sub": "auth-user-1", "email": "invited@example.com"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_activate_user_creates_user_from_invite(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
tx = object()
|
||||
invited_user = _invited_user_db_record(
|
||||
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
|
||||
)
|
||||
created_user = _user_db_record()
|
||||
|
||||
outside_user_repo = Mock()
|
||||
# Only called once at post-transaction verification (line 741);
|
||||
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
|
||||
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
|
||||
|
||||
inside_user_repo = Mock()
|
||||
inside_user_repo.find_unique = AsyncMock(return_value=None)
|
||||
inside_user_repo.create = AsyncMock(return_value=created_user)
|
||||
|
||||
outside_invited_repo = Mock()
|
||||
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
|
||||
inside_invited_repo = Mock()
|
||||
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
|
||||
inside_invited_repo.update = AsyncMock(return_value=invited_user)
|
||||
|
||||
def user_prisma(client=None):
|
||||
return inside_user_repo if client is tx else outside_user_repo
|
||||
|
||||
def invited_user_prisma(client=None):
|
||||
return inside_invited_repo if client is tx else outside_invited_repo
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_transaction():
|
||||
yield tx
|
||||
|
||||
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
|
||||
# not prisma.models.User.prisma().find_unique which we mock above.
|
||||
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
|
||||
mock_get_user_by_id.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_id",
|
||||
mock_get_user_by_id,
|
||||
)
|
||||
mock_get_user_by_email = AsyncMock()
|
||||
mock_get_user_by_email.cache_delete = Mock()
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.get_user_by_email",
|
||||
mock_get_user_by_email,
|
||||
)
|
||||
ensure_profile = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_profile",
|
||||
AsyncMock(),
|
||||
)
|
||||
ensure_onboarding = mocker.patch(
|
||||
"backend.data.invited_user._ensure_default_onboarding",
|
||||
AsyncMock(),
|
||||
)
|
||||
apply_tally = mocker.patch(
|
||||
"backend.data.invited_user._apply_tally_understanding",
|
||||
AsyncMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user._settings.config.enable_invite_gate",
|
||||
True,
|
||||
)
|
||||
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
side_effect=invited_user_prisma,
|
||||
)
|
||||
|
||||
user = await get_or_activate_user(
|
||||
{
|
||||
"sub": "auth-user-1",
|
||||
"email": "Invited@example.com",
|
||||
"user_metadata": {"name": "Invited User"},
|
||||
}
|
||||
)
|
||||
|
||||
assert user.id == "auth-user-1"
|
||||
inside_user_repo.create.assert_awaited_once()
|
||||
inside_invited_repo.update.assert_awaited_once()
|
||||
ensure_profile.assert_awaited_once()
|
||||
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
|
||||
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_from_text_file(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
_invited_user_record(),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.txt",
|
||||
b"Invited@example.com\nsecond@example.com\n",
|
||||
)
|
||||
|
||||
assert result.created_count == 2
|
||||
assert result.skipped_count == 0
|
||||
assert result.error_count == 0
|
||||
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
create_invited = mocker.patch(
|
||||
"backend.data.invited_user.create_invited_user",
|
||||
AsyncMock(
|
||||
side_effect=[
|
||||
_invited_user_record(),
|
||||
PreconditionFailed("An invited user with this email already exists"),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
result = await bulk_create_invited_users_from_file(
|
||||
"invites.csv",
|
||||
(
|
||||
"email,name\n"
|
||||
"valid@example.com,Valid User\n"
|
||||
"not-an-email,Bad Row\n"
|
||||
"valid@example.com,Duplicate In File\n"
|
||||
"existing@example.com,Existing User\n"
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
assert result.created_count == 1
|
||||
assert result.skipped_count == 2
|
||||
assert result.error_count == 1
|
||||
assert [row.status for row in result.results] == [
|
||||
"CREATED",
|
||||
"ERROR",
|
||||
"SKIPPED",
|
||||
"SKIPPED",
|
||||
]
|
||||
assert create_invited.await_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_invite_eligibility tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_true_for_invited(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
invited = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.INVITED)
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=invited)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("invited@example.com")
|
||||
assert result is True
|
||||
repo.find_unique.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_false_for_no_record(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("unknown@example.com")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_false_for_claimed(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
claimed = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.CLAIMED)
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=claimed)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("claimed@example.com")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_eligibility_returns_false_for_revoked(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.REVOKED)
|
||||
repo = Mock()
|
||||
repo.find_unique = AsyncMock(return_value=revoked)
|
||||
mocker.patch(
|
||||
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
|
||||
return_value=repo,
|
||||
)
|
||||
|
||||
result = await check_invite_eligibility("revoked@example.com")
|
||||
assert result is False
|
||||
@@ -40,11 +40,8 @@ _MAX_PAGES = 100
|
||||
# LLM extraction timeout (seconds)
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
SUGGESTION_THEMES = ["Learn", "Create", "Automate", "Organize"]
|
||||
PROMPTS_PER_THEME = 5
|
||||
|
||||
|
||||
def mask_email(email: str) -> str:
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
try:
|
||||
local, domain = email.rsplit("@", 1)
|
||||
@@ -199,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
|
||||
|
||||
Returns (email_index, questions).
|
||||
"""
|
||||
client = _make_tally_client(_settings.secrets.tally_api_key)
|
||||
settings = Settings()
|
||||
client = _make_tally_client(settings.secrets.tally_api_key)
|
||||
|
||||
redis = await get_redis_async()
|
||||
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
|
||||
@@ -334,11 +332,6 @@ Fields:
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
- suggested_prompts (object with keys "Learn", "Create", "Automate", "Organize"): for each key, \
|
||||
provide a list of 5 short action prompts (each under 20 words) that would help this person. \
|
||||
"Learn" = questions about AutoGPT features; "Create" = content/document generation tasks; \
|
||||
"Automate" = recurring workflow automation ideas; "Organize" = structuring/prioritizing tasks. \
|
||||
Should be specific to their industry, role, and pain points; actionable and conversational in tone.
|
||||
|
||||
Form data:
|
||||
"""
|
||||
@@ -346,21 +339,21 @@ Form data:
|
||||
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
async def extract_business_understanding_from_tally(
|
||||
async def extract_business_understanding(
|
||||
formatted_text: str,
|
||||
) -> BusinessUnderstandingInput:
|
||||
"""
|
||||
Use an LLM to extract structured business understanding from form text.
|
||||
"""Use an LLM to extract structured business understanding from form text.
|
||||
|
||||
Raises on timeout or unparseable response so the caller can handle it.
|
||||
"""
|
||||
api_key = _settings.secrets.open_router_api_key
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.open_router_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model=_settings.config.tally_extraction_llm_model,
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
@@ -385,60 +378,9 @@ async def extract_business_understanding_from_tally(
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
# Validate suggested_prompts: themed dict, filter >20 words, cap at 5 per theme
|
||||
raw_prompts = cleaned.get("suggested_prompts", {})
|
||||
if isinstance(raw_prompts, dict):
|
||||
themed: dict[str, list[str]] = {}
|
||||
for theme in SUGGESTION_THEMES:
|
||||
theme_prompts = raw_prompts.get(theme, [])
|
||||
if not isinstance(theme_prompts, list):
|
||||
continue
|
||||
valid = [
|
||||
s
|
||||
for p in theme_prompts
|
||||
if isinstance(p, str) and (s := p.strip()) and len(s.split()) <= 20
|
||||
]
|
||||
if valid:
|
||||
themed[theme] = valid[:PROMPTS_PER_THEME]
|
||||
if themed:
|
||||
cleaned["suggested_prompts"] = themed
|
||||
else:
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
else:
|
||||
cleaned.pop("suggested_prompts", None)
|
||||
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def get_business_understanding_input_from_tally(
|
||||
email: str,
|
||||
*,
|
||||
require_api_key: bool = False,
|
||||
) -> Optional[BusinessUnderstandingInput]:
|
||||
if not _settings.secrets.tally_api_key:
|
||||
if require_api_key:
|
||||
raise RuntimeError("Tally API key is not configured")
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return None
|
||||
|
||||
masked = mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return None
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return None
|
||||
|
||||
return await extract_business_understanding_from_tally(formatted)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
@@ -453,9 +395,32 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
understanding_input = await get_business_understanding_input_from_tally(email)
|
||||
if understanding_input is None:
|
||||
# Check required config is present
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key or not settings.secrets.tally_form_id:
|
||||
logger.debug("Tally: Tally config incomplete, skipping")
|
||||
return
|
||||
if not settings.secrets.open_router_api_key:
|
||||
logger.debug("Tally: no OpenRouter API key configured, skipping")
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(settings.secrets.tally_form_id, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
|
||||
@@ -12,11 +12,11 @@ from backend.data.tally import (
|
||||
_build_email_index,
|
||||
_format_answer,
|
||||
_make_tally_client,
|
||||
_mask_email,
|
||||
_refresh_cache,
|
||||
extract_business_understanding_from_tally,
|
||||
extract_business_understanding,
|
||||
find_submission_by_email,
|
||||
format_submission_for_llm,
|
||||
mask_email,
|
||||
populate_understanding_from_tally,
|
||||
)
|
||||
|
||||
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_input.suggested_prompts = {"Learn": ["P1"], "Create": ["P2"]}
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
) as mock_extract,
|
||||
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding_from_tally",
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
),
|
||||
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
|
||||
|
||||
def test_mask_email():
|
||||
assert mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert mask_email("ab@example.com") == "a***@example.com"
|
||||
assert mask_email("a@example.com") == "a***@example.com"
|
||||
assert _mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert _mask_email("ab@example.com") == "a***@example.com"
|
||||
assert _mask_email("a@example.com") == "a***@example.com"
|
||||
|
||||
|
||||
def test_mask_email_invalid():
|
||||
assert mask_email("no-at-sign") == "***"
|
||||
assert _mask_email("no-at-sign") == "***"
|
||||
|
||||
|
||||
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
|
||||
@@ -394,29 +393,19 @@ def test_extraction_prompt_no_format_placeholders():
|
||||
assert single_braces == [], f"Found format placeholders: {single_braces}"
|
||||
|
||||
|
||||
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
|
||||
# ── extract_business_understanding ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_themed_prompts():
|
||||
"""Happy path: LLM returns themed prompts as dict."""
|
||||
async def test_extract_business_understanding_success():
|
||||
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"business_name": "Acme Corp",
|
||||
"suggested_prompts": {
|
||||
"Learn": ["Learn 1", "Learn 2", "Learn 3", "Learn 4", "Learn 5"],
|
||||
"Create": [
|
||||
"Create 1",
|
||||
"Create 2",
|
||||
"Create 3",
|
||||
"Create 4",
|
||||
"Create 5",
|
||||
],
|
||||
"Automate": ["Auto 1", "Auto 2", "Auto 3", "Auto 4", "Auto 5"],
|
||||
"Organize": ["Org 1", "Org 2", "Org 3", "Org 4", "Org 5"],
|
||||
},
|
||||
"industry": "Technology",
|
||||
"pain_points": ["manual reporting"],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
@@ -426,49 +415,16 @@ async def test_extract_business_understanding_themed_prompts():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.suggested_prompts is not None
|
||||
assert len(result.suggested_prompts) == 4
|
||||
assert len(result.suggested_prompts["Learn"]) == 5
|
||||
assert result.business_name == "Acme Corp"
|
||||
assert result.industry == "Technology"
|
||||
assert result.pain_points == ["manual reporting"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_themed_prompts_filters_long_and_unknown_keys():
|
||||
"""Long prompts are filtered, unknown keys are dropped, each theme capped at 5."""
|
||||
long_prompt = " ".join(["word"] * 21)
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"user_name": "Alice",
|
||||
"suggested_prompts": {
|
||||
"Learn": [long_prompt, "Valid learn 1", "Valid learn 2"],
|
||||
"UnknownTheme": ["Should be dropped"],
|
||||
"Automate": ["A1", "A2", "A3", "A4", "A5", "A6"],
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.suggested_prompts is not None
|
||||
# Unknown key dropped
|
||||
assert "UnknownTheme" not in result.suggested_prompts
|
||||
# Long prompt filtered
|
||||
assert result.suggested_prompts["Learn"] == ["Valid learn 1", "Valid learn 2"]
|
||||
# Capped at 5
|
||||
assert result.suggested_prompts["Automate"] == ["A1", "A2", "A3", "A4", "A5"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
async def test_extract_business_understanding_filters_nulls():
|
||||
"""Null values from LLM should be excluded from the result."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
@@ -481,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
|
||||
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
result = await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
assert result.user_name == "Alice"
|
||||
assert result.business_name is None
|
||||
@@ -489,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
async def test_extract_business_understanding_invalid_json():
|
||||
"""Invalid JSON from LLM should raise JSONDecodeError."""
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "not valid json {"
|
||||
@@ -503,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
|
||||
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
|
||||
pytest.raises(json.JSONDecodeError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_business_understanding_from_tally_timeout():
|
||||
async def test_extract_business_understanding_timeout():
|
||||
"""LLM timeout should propagate as asyncio.TimeoutError."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
@@ -517,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
|
||||
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
|
||||
await extract_business_understanding("Q: Name?\nA: Alice")
|
||||
|
||||
|
||||
# ── _refresh_cache ───────────────────────────────────────────────────────────
|
||||
@@ -536,7 +492,7 @@ async def test_refresh_cache_full_fetch():
|
||||
submissions = SAMPLE_SUBMISSIONS
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
@@ -584,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
|
||||
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
|
||||
|
||||
with (
|
||||
patch("backend.data.tally._settings", mock_settings),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
|
||||
@@ -31,25 +31,6 @@ def _json_to_list(value: Any) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _json_to_themed_prompts(value: Any) -> dict[str, list[str]]:
|
||||
"""Convert Json field to themed prompts dict.
|
||||
|
||||
Handles both the new ``dict[str, list[str]]`` format and the legacy
|
||||
``list[str]`` format. Legacy rows are placed under a ``"General"`` key so
|
||||
existing personalised prompts remain readable until a backfill regenerates
|
||||
them into the proper themed shape.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
k: [i for i in v if isinstance(i, str)]
|
||||
for k, v in value.items()
|
||||
if isinstance(k, str) and isinstance(v, list)
|
||||
}
|
||||
if isinstance(value, list) and value:
|
||||
return {"General": [str(p) for p in value if isinstance(p, str)]}
|
||||
return {}
|
||||
|
||||
|
||||
class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
"""Input model for updating business understanding - all fields optional for incremental updates."""
|
||||
|
||||
@@ -105,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
|
||||
None, description="Any additional context"
|
||||
)
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: Optional[dict[str, list[str]]] = pydantic.Field(
|
||||
None, description="LLM-generated suggested prompts grouped by theme"
|
||||
)
|
||||
|
||||
|
||||
class BusinessUnderstanding(pydantic.BaseModel):
|
||||
"""Full business understanding model returned from database."""
|
||||
@@ -146,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
|
||||
# Suggested prompts (UI-only, not included in system prompt)
|
||||
suggested_prompts: dict[str, list[str]] = pydantic.Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
|
||||
"""Convert database record to Pydantic model."""
|
||||
@@ -176,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
suggested_prompts=_json_to_themed_prompts(data.get("suggested_prompts")),
|
||||
)
|
||||
|
||||
|
||||
@@ -194,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
|
||||
return merged
|
||||
|
||||
|
||||
def merge_business_understanding_data(
|
||||
existing_data: dict[str, Any],
|
||||
input_data: BusinessUnderstandingInput,
|
||||
) -> dict[str, Any]:
|
||||
merged_data = dict(existing_data)
|
||||
|
||||
merged_business: dict[str, Any] = {}
|
||||
if isinstance(merged_data.get("business"), dict):
|
||||
merged_business = dict(merged_data["business"])
|
||||
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
if input_data.user_name is not None:
|
||||
merged_data["name"] = input_data.user_name
|
||||
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
merged_business[field] = value
|
||||
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(merged_business.get(field))
|
||||
merged_list = _merge_lists(existing_list, value)
|
||||
merged_business[field] = merged_list
|
||||
|
||||
merged_business["version"] = 1
|
||||
merged_data["business"] = merged_business
|
||||
|
||||
# suggested_prompts lives at the top level (not under `business`) because
|
||||
# it's a UI-only artifact consumed by the frontend, not business understanding
|
||||
# data. The `business` sub-dict feeds the system prompt.
|
||||
if input_data.suggested_prompts is not None:
|
||||
merged_data["suggested_prompts"] = input_data.suggested_prompts
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
|
||||
"""Get business understanding from Redis cache."""
|
||||
try:
|
||||
@@ -329,18 +245,63 @@ async def upsert_business_understanding(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
|
||||
# Get existing data structure or start fresh
|
||||
existing_data: dict[str, Any] = {}
|
||||
if existing and isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
merged_data = merge_business_understanding_data(existing_data, input_data)
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
# Business fields (stored inside business object)
|
||||
business_string_fields = [
|
||||
"job_title",
|
||||
"business_name",
|
||||
"industry",
|
||||
"business_size",
|
||||
"user_role",
|
||||
"additional_notes",
|
||||
]
|
||||
business_list_fields = [
|
||||
"key_workflows",
|
||||
"daily_activities",
|
||||
"pain_points",
|
||||
"bottlenecks",
|
||||
"manual_tasks",
|
||||
"automation_goals",
|
||||
"current_software",
|
||||
"existing_automation",
|
||||
]
|
||||
|
||||
# Handle top-level name field
|
||||
if input_data.user_name is not None:
|
||||
existing_data["name"] = input_data.user_name
|
||||
|
||||
# Business string fields - overwrite if provided
|
||||
for field in business_string_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_business[field] = value
|
||||
|
||||
# Business list fields - merge with existing
|
||||
for field in business_list_fields:
|
||||
value = getattr(input_data, field)
|
||||
if value is not None:
|
||||
existing_list = _json_to_list(existing_business.get(field))
|
||||
merged = _merge_lists(existing_list, value)
|
||||
existing_business[field] = merged
|
||||
|
||||
# Set version and nest business data
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
# Upsert with the merged data
|
||||
record = await CoPilotUnderstanding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(merged_data)},
|
||||
"update": {"data": SafeJson(merged_data)},
|
||||
"create": {"userId": user_id, "data": SafeJson(existing_data)},
|
||||
"update": {"data": SafeJson(existing_data)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
"""Tests for business understanding merge and format logic."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
_json_to_themed_prompts,
|
||||
format_understanding_for_prompt,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
|
||||
"""Create a BusinessUnderstandingInput with only the specified fields."""
|
||||
return BusinessUnderstandingInput.model_validate(kwargs)
|
||||
|
||||
|
||||
# ─── merge_business_understanding_data: themed prompts ─────────────────
|
||||
|
||||
|
||||
def test_merge_themed_prompts_overwrites_existing():
|
||||
"""New themed prompts should fully replace existing ones (not merge)."""
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": {
|
||||
"Learn": ["Old learn prompt"],
|
||||
"Create": ["Old create prompt"],
|
||||
},
|
||||
}
|
||||
new_prompts = {
|
||||
"Automate": ["Schedule daily reports", "Set up email alerts"],
|
||||
"Organize": ["Sort inbox by priority"],
|
||||
}
|
||||
input_data = _make_input(suggested_prompts=new_prompts)
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == new_prompts
|
||||
|
||||
|
||||
def test_merge_themed_prompts_none_preserves_existing():
|
||||
"""When input has suggested_prompts=None, existing themed prompts are preserved."""
|
||||
existing_prompts = {
|
||||
"Learn": ["How to automate?"],
|
||||
"Create": ["Build a chatbot"],
|
||||
}
|
||||
existing = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": existing_prompts,
|
||||
}
|
||||
input_data = _make_input(industry="Finance")
|
||||
|
||||
result = merge_business_understanding_data(existing, input_data)
|
||||
|
||||
assert result["suggested_prompts"] == existing_prompts
|
||||
assert result["business"]["industry"] == "Finance"
|
||||
|
||||
|
||||
# ─── from_db: themed prompts deserialization ───────────────────────────
|
||||
|
||||
|
||||
def test_from_db_themed_prompts():
|
||||
"""from_db correctly deserializes a themed dict for suggested_prompts."""
|
||||
themed = {
|
||||
"Learn": ["What can I automate?"],
|
||||
"Create": ["Build a workflow"],
|
||||
}
|
||||
db_record = MagicMock()
|
||||
db_record.id = "test-id"
|
||||
db_record.userId = "user-1"
|
||||
db_record.createdAt = datetime.now(tz=timezone.utc)
|
||||
db_record.updatedAt = datetime.now(tz=timezone.utc)
|
||||
db_record.data = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": themed,
|
||||
}
|
||||
|
||||
result = BusinessUnderstanding.from_db(db_record)
|
||||
|
||||
assert result.suggested_prompts == themed
|
||||
|
||||
|
||||
def test_from_db_legacy_list_prompts_preserved_under_general():
|
||||
"""from_db preserves legacy list[str] prompts under a 'General' key."""
|
||||
db_record = MagicMock()
|
||||
db_record.id = "test-id"
|
||||
db_record.userId = "user-1"
|
||||
db_record.createdAt = datetime.now(tz=timezone.utc)
|
||||
db_record.updatedAt = datetime.now(tz=timezone.utc)
|
||||
db_record.data = {
|
||||
"name": "Alice",
|
||||
"business": {"industry": "Tech", "version": 1},
|
||||
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
|
||||
}
|
||||
|
||||
result = BusinessUnderstanding.from_db(db_record)
|
||||
|
||||
assert result.suggested_prompts == {"General": ["Old prompt 1", "Old prompt 2"]}
|
||||
|
||||
|
||||
# ─── _json_to_themed_prompts helper ───────────────────────────────────
|
||||
|
||||
|
||||
def test_json_to_themed_prompts_with_dict():
|
||||
value = {"Learn": ["a", "b"], "Create": ["c"]}
|
||||
assert _json_to_themed_prompts(value) == {"Learn": ["a", "b"], "Create": ["c"]}
|
||||
|
||||
|
||||
def test_json_to_themed_prompts_with_list_returns_general():
|
||||
assert _json_to_themed_prompts(["a", "b"]) == {"General": ["a", "b"]}
|
||||
|
||||
|
||||
def test_json_to_themed_prompts_with_none_returns_empty():
|
||||
assert _json_to_themed_prompts(None) == {}
|
||||
|
||||
|
||||
# ─── format_understanding_for_prompt: excludes themed prompts ──────────
|
||||
|
||||
|
||||
def test_format_understanding_excludes_themed_prompts():
|
||||
"""Themed suggested_prompts are UI-only and must NOT appear in the system prompt."""
|
||||
understanding = BusinessUnderstanding(
|
||||
id="test-id",
|
||||
user_id="user-1",
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
updated_at=datetime.now(tz=timezone.utc),
|
||||
user_name="Alice",
|
||||
industry="Technology",
|
||||
suggested_prompts={
|
||||
"Learn": ["Automate reports"],
|
||||
"Create": ["Set up alerts", "Track KPIs"],
|
||||
},
|
||||
)
|
||||
|
||||
formatted = format_understanding_for_prompt(understanding)
|
||||
|
||||
assert "Alice" in formatted
|
||||
assert "Technology" in formatted
|
||||
assert "suggested_prompts" not in formatted
|
||||
assert "Automate reports" not in formatted
|
||||
assert "Set up alerts" not in formatted
|
||||
assert "Track KPIs" not in formatted
|
||||
@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
tally_extraction_llm_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="OpenRouter model ID used for extracting business understanding from Tally form data",
|
||||
)
|
||||
ollama_host: str = Field(
|
||||
default="localhost:11434",
|
||||
description="Default Ollama host; exempted from SSRF checks.",
|
||||
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=True,
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
enable_invite_gate: bool = Field(
|
||||
default=False,
|
||||
description="If the invite-only signup gate is enforced",
|
||||
)
|
||||
enable_credit: bool = Field(
|
||||
default=False,
|
||||
description="If user credit system is enabled or not",
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
-- Revert the invite system: drop InvitedUser table + enums, restore User+Profile trigger.
|
||||
-- Uses current_schema() so the migration works regardless of the configured schema name.
|
||||
|
||||
-- 1) Drop the InvitedUser table (also drops its indexes and FK constraints)
|
||||
DROP TABLE IF EXISTS "InvitedUser";
|
||||
|
||||
-- 2) Drop the enums introduced by the invite system
|
||||
DROP TYPE IF EXISTS "InvitedUserStatus";
|
||||
DROP TYPE IF EXISTS "TallyComputationStatus";
|
||||
|
||||
-- 3) Restore the User+Profile auto-creation trigger on auth.users.
|
||||
-- Original definition from migration 20250205100104_add_profile_trigger.
|
||||
-- generate_username() was never dropped and is still present.
|
||||
-- Uses EXECUTE + format() to inject current_schema() into SET search_path,
|
||||
-- so the trigger function resolves tables correctly when fired from auth.
|
||||
|
||||
DO $$
|
||||
DECLARE
|
||||
cs text := current_schema();
|
||||
BEGIN
|
||||
EXECUTE format($fn$
|
||||
CREATE OR REPLACE FUNCTION add_user_and_profile_to_platform()
|
||||
RETURNS TRIGGER
|
||||
LANGUAGE plpgsql
|
||||
SECURITY DEFINER
|
||||
SET search_path = %I
|
||||
AS $trigger$
|
||||
BEGIN
|
||||
IF NEW.id IS NULL THEN
|
||||
RAISE EXCEPTION 'Cannot create user/profile: id is null';
|
||||
END IF;
|
||||
|
||||
INSERT INTO "User" (id, email, "updatedAt")
|
||||
VALUES (NEW.id, NEW.email, now());
|
||||
|
||||
INSERT INTO "Profile"
|
||||
("id", "userId", name, username, description, links, "avatarUrl", "updatedAt")
|
||||
VALUES (
|
||||
NEW.id,
|
||||
NEW.id,
|
||||
COALESCE(split_part(NEW.email, '@', 1), 'user'),
|
||||
generate_username(),
|
||||
'I''m new here',
|
||||
'{}',
|
||||
'',
|
||||
now()
|
||||
);
|
||||
|
||||
RETURN NEW;
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
RAISE NOTICE 'Error in add_user_and_profile_to_platform: %%', SQLERRM;
|
||||
RAISE;
|
||||
END;
|
||||
$trigger$
|
||||
$fn$, cs);
|
||||
END $$;
|
||||
|
||||
-- 4) Backfill: create User + Profile rows for any auth.users rows that were
|
||||
-- created while the trigger was absent (during the invite-system window).
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = 'auth' AND table_name = 'users'
|
||||
) THEN
|
||||
INSERT INTO "User" (id, email, "updatedAt")
|
||||
SELECT au.id::text, au.email, now()
|
||||
FROM auth.users au
|
||||
LEFT JOIN "User" pu ON pu.id = au.id::text
|
||||
WHERE pu.id IS NULL
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
INSERT INTO "Profile"
|
||||
(id, "userId", name, username, description, links, "avatarUrl", "updatedAt")
|
||||
SELECT
|
||||
gen_random_uuid()::text,
|
||||
au.id::text,
|
||||
COALESCE(NULLIF(split_part(au.email, '@', 1), ''), 'user'),
|
||||
generate_username(),
|
||||
'I''m new here',
|
||||
'{}',
|
||||
'',
|
||||
now()
|
||||
FROM auth.users au
|
||||
LEFT JOIN "Profile" pp ON pp."userId" = au.id::text
|
||||
WHERE pp."userId" IS NULL
|
||||
ON CONFLICT ("userId") DO NOTHING;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- 5) Restore the trigger for future signups.
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'auth'
|
||||
AND table_name = 'users'
|
||||
) THEN
|
||||
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
|
||||
|
||||
CREATE TRIGGER user_added_to_platform
|
||||
AFTER INSERT ON auth.users
|
||||
FOR EACH ROW EXECUTE FUNCTION add_user_and_profile_to_platform();
|
||||
END IF;
|
||||
END $$;
|
||||
@@ -65,7 +65,6 @@ model User {
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -74,38 +73,6 @@ model User {
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
}
|
||||
|
||||
enum InvitedUserStatus {
|
||||
INVITED
|
||||
CLAIMED
|
||||
REVOKED
|
||||
}
|
||||
|
||||
enum TallyComputationStatus {
|
||||
PENDING
|
||||
RUNNING
|
||||
READY
|
||||
FAILED
|
||||
}
|
||||
|
||||
model InvitedUser {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
email String @unique
|
||||
status InvitedUserStatus @default(INVITED)
|
||||
authUserId String? @unique
|
||||
AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
|
||||
name String?
|
||||
|
||||
tallyUnderstanding Json?
|
||||
tallyStatus TallyComputationStatus @default(PENDING)
|
||||
tallyComputedAt DateTime?
|
||||
tallyError String?
|
||||
|
||||
@@index([createdAt])
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
// Introductory onboarding (Library)
|
||||
WELCOME
|
||||
@@ -1025,7 +992,7 @@ model StoreListing {
|
||||
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
|
||||
|
||||
// The agent link here is only so we can do lookup on agentId
|
||||
agentGraphId String @unique
|
||||
agentGraphId String @unique
|
||||
|
||||
owningUserId String
|
||||
OwningUser User @relation(fields: [owningUserId], references: [id])
|
||||
|
||||
@@ -34,7 +34,7 @@ from backend.data.auth.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
faker = Faker()
|
||||
@@ -151,7 +151,7 @@ class TestDataCreator:
|
||||
}
|
||||
|
||||
# Use the API function to create user in local database
|
||||
user = await get_or_activate_user(user_data)
|
||||
user = await get_or_create_user(user_data)
|
||||
users.append(user.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import {
|
||||
UsersIcon,
|
||||
CurrencyDollarSimpleIcon,
|
||||
UserPlusIcon,
|
||||
MagnifyingGlassIcon,
|
||||
FileTextIcon,
|
||||
SlidersHorizontalIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
const sidebarLinkGroups = [
|
||||
{
|
||||
@@ -16,32 +9,27 @@ const sidebarLinkGroups = [
|
||||
{
|
||||
text: "Marketplace Management",
|
||||
href: "/admin/marketplace",
|
||||
icon: <UsersIcon size={24} />,
|
||||
icon: <Users className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Spending",
|
||||
href: "/admin/spending",
|
||||
icon: <CurrencyDollarSimpleIcon size={24} />,
|
||||
},
|
||||
{
|
||||
text: "Beta Invites",
|
||||
href: "/admin/users",
|
||||
icon: <UserPlusIcon size={24} />,
|
||||
icon: <DollarSign className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Impersonation",
|
||||
href: "/admin/impersonation",
|
||||
icon: <MagnifyingGlassIcon size={24} />,
|
||||
icon: <UserSearch className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
icon: <FileTextIcon size={24} />,
|
||||
icon: <FileText className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Admin User Management",
|
||||
href: "/admin/settings",
|
||||
icon: <SlidersHorizontalIcon size={24} />,
|
||||
icon: <IconSliders className="h-6 w-6" />,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Card } from "@/components/atoms/Card/Card";
|
||||
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
|
||||
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
|
||||
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
|
||||
import { useAdminUsersPage } from "../../useAdminUsersPage";
|
||||
|
||||
export function AdminUsersPage() {
|
||||
const {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers,
|
||||
isLoadingInvitedUsers,
|
||||
isRefreshingInvitedUsers,
|
||||
isCreatingInvite,
|
||||
isBulkInviting,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
} = useAdminUsersPage();
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
|
||||
<p className="max-w-3xl text-sm text-zinc-600">
|
||||
Pre-provision beta users before they sign up. Invites store the
|
||||
platform-side record, run Tally understanding extraction, and activate
|
||||
the real account on the user's first authenticated request.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
|
||||
<div className="flex flex-col gap-6">
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InviteUserForm
|
||||
email={email}
|
||||
name={name}
|
||||
isSubmitting={isCreatingInvite}
|
||||
onEmailChange={setEmail}
|
||||
onNameChange={setName}
|
||||
onSubmit={handleCreateInvite}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<BulkInviteForm
|
||||
selectedFile={bulkInviteFile}
|
||||
inputKey={bulkInviteInputKey}
|
||||
isSubmitting={isBulkInviting}
|
||||
lastResult={lastBulkInviteResult}
|
||||
onFileChange={handleBulkInviteFileChange}
|
||||
onSubmit={handleBulkInviteSubmit}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Card className="border border-zinc-200 shadow-sm">
|
||||
<InvitedUsersTable
|
||||
invitedUsers={invitedUsers}
|
||||
isLoading={isLoadingInvitedUsers}
|
||||
isRefreshing={isRefreshingInvitedUsers}
|
||||
pendingInviteAction={pendingInviteAction}
|
||||
onRetryTally={handleRetryTally}
|
||||
onRevoke={handleRevoke}
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
selectedFile: File | null;
|
||||
inputKey: number;
|
||||
isSubmitting: boolean;
|
||||
lastResult: BulkInvitedUsersResponse | null;
|
||||
onFileChange: (file: File | null) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
|
||||
if (status === "CREATED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "ERROR") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
export function BulkInviteForm({
|
||||
selectedFile,
|
||||
inputKey,
|
||||
isSubmitting,
|
||||
lastResult,
|
||||
onFileChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Upload a <span className="font-medium text-zinc-800">.txt</span> file
|
||||
with one email per line, or a{" "}
|
||||
<span className="font-medium text-zinc-800">.csv</span> with
|
||||
<span className="font-medium text-zinc-800"> email</span> and optional
|
||||
<span className="font-medium text-zinc-800"> name</span> columns.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<label
|
||||
htmlFor="bulk-invite-file-input"
|
||||
className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors focus-within:ring-2 focus-within:ring-zinc-500 focus-within:ring-offset-2 hover:border-zinc-400 hover:bg-zinc-100"
|
||||
>
|
||||
<span className="font-medium text-zinc-900">
|
||||
{selectedFile ? selectedFile.name : "Choose invite file"}
|
||||
</span>
|
||||
<span>Maximum 500 rows, UTF-8 encoded.</span>
|
||||
<input
|
||||
id="bulk-invite-file-input"
|
||||
key={inputKey}
|
||||
type="file"
|
||||
accept=".txt,.csv,text/plain,text/csv"
|
||||
disabled={isSubmitting}
|
||||
className="sr-only"
|
||||
onChange={(event) =>
|
||||
onFileChange(event.target.files?.item(0) ?? null)
|
||||
}
|
||||
/>
|
||||
</label>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!selectedFile}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
|
||||
</Button>
|
||||
|
||||
{lastResult ? (
|
||||
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
|
||||
<div className="grid grid-cols-3 gap-2 text-center">
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.created_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Created
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.skipped_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Skipped
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-xl bg-white px-3 py-2">
|
||||
<div className="text-lg font-semibold text-zinc-900">
|
||||
{lastResult.error_count}
|
||||
</div>
|
||||
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
|
||||
Errors
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
|
||||
<div className="flex flex-col divide-y divide-zinc-100">
|
||||
{lastResult.results.map((row) => (
|
||||
<div
|
||||
key={`${row.row_number}-${row.email ?? row.message}`}
|
||||
className="flex items-start gap-3 px-3 py-3"
|
||||
>
|
||||
<Badge variant={getStatusVariant(row.status)} size="small">
|
||||
{row.status}
|
||||
</Badge>
|
||||
<div className="flex min-w-0 flex-1 flex-col gap-1">
|
||||
<span className="text-sm font-medium text-zinc-900">
|
||||
Row {row.row_number}
|
||||
{row.email ? ` · ${row.email}` : ""}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-500">{row.message}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import type { FormEvent } from "react";
|
||||
|
||||
interface Props {
|
||||
email: string;
|
||||
name: string;
|
||||
isSubmitting: boolean;
|
||||
onEmailChange: (value: string) => void;
|
||||
onNameChange: (value: string) => void;
|
||||
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
|
||||
}
|
||||
|
||||
export function InviteUserForm({
|
||||
email,
|
||||
name,
|
||||
isSubmitting,
|
||||
onEmailChange,
|
||||
onNameChange,
|
||||
onSubmit,
|
||||
}: Props) {
|
||||
return (
|
||||
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
The invite is stored immediately, then Tally pre-seeding starts in the
|
||||
background.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Input
|
||||
id="invite-email"
|
||||
label="Email"
|
||||
type="email"
|
||||
value={email}
|
||||
placeholder="jane@example.com"
|
||||
autoComplete="email"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onEmailChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Input
|
||||
id="invite-name"
|
||||
label="Name"
|
||||
type="text"
|
||||
value={name}
|
||||
placeholder="Jane Doe"
|
||||
disabled={isSubmitting}
|
||||
onChange={(event) => onNameChange(event.target.value)}
|
||||
/>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
variant="primary"
|
||||
loading={isSubmitting}
|
||||
disabled={!email.trim()}
|
||||
className="w-full"
|
||||
>
|
||||
{isSubmitting ? "Creating invite..." : "Create invite"}
|
||||
</Button>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,209 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
interface Props {
|
||||
invitedUsers: InvitedUserResponse[];
|
||||
isLoading: boolean;
|
||||
isRefreshing: boolean;
|
||||
pendingInviteAction: string | null;
|
||||
onRetryTally: (invitedUserId: string) => void;
|
||||
onRevoke: (invitedUserId: string) => void;
|
||||
}
|
||||
|
||||
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
|
||||
if (status === "CLAIMED") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "REVOKED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
|
||||
if (status === "READY") {
|
||||
return "success";
|
||||
}
|
||||
|
||||
if (status === "FAILED") {
|
||||
return "error";
|
||||
}
|
||||
|
||||
return "info";
|
||||
}
|
||||
|
||||
function formatDate(value: Date | undefined) {
|
||||
if (!value) {
|
||||
return "-";
|
||||
}
|
||||
|
||||
return value.toLocaleString();
|
||||
}
|
||||
|
||||
function getTallySummary(invitedUser: InvitedUserResponse) {
|
||||
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
|
||||
return invitedUser.tally_error;
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
|
||||
return "Stored and ready for activation";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "READY") {
|
||||
return "No matching Tally submission found";
|
||||
}
|
||||
|
||||
if (invitedUser.tally_status === "RUNNING") {
|
||||
return "Extraction in progress";
|
||||
}
|
||||
|
||||
return "Waiting to run";
|
||||
}
|
||||
|
||||
function isActionPending(
|
||||
pendingInviteAction: string | null,
|
||||
action: "retry" | "revoke",
|
||||
invitedUserId: string,
|
||||
) {
|
||||
return pendingInviteAction === `${action}:${invitedUserId}`;
|
||||
}
|
||||
|
||||
export function InvitedUsersTable({
|
||||
invitedUsers,
|
||||
isLoading,
|
||||
isRefreshing,
|
||||
pendingInviteAction,
|
||||
onRetryTally,
|
||||
onRevoke,
|
||||
}: Props) {
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
|
||||
<p className="text-sm text-zinc-600">
|
||||
Live invite state, claim status, and Tally pre-seeding progress.
|
||||
</p>
|
||||
</div>
|
||||
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
|
||||
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="overflow-hidden rounded-2xl border border-zinc-200">
|
||||
<Table>
|
||||
<TableHeader className="bg-zinc-50">
|
||||
<TableRow>
|
||||
<TableHead>Email</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Invite</TableHead>
|
||||
<TableHead>Tally</TableHead>
|
||||
<TableHead>Claimed User</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead className="text-right">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{isLoading ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
Loading invited users...
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : invitedUsers.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={7}
|
||||
className="py-10 text-center text-zinc-500"
|
||||
>
|
||||
No invited users yet
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
invitedUsers.map((invitedUser) => (
|
||||
<TableRow key={invitedUser.id} className="align-top">
|
||||
<TableCell className="font-medium text-zinc-900">
|
||||
{invitedUser.email}
|
||||
</TableCell>
|
||||
<TableCell>{invitedUser.name || "-"}</TableCell>
|
||||
<TableCell>
|
||||
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
|
||||
{invitedUser.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex max-w-xs flex-col gap-2">
|
||||
<Badge
|
||||
variant={getTallyBadgeVariant(invitedUser.tally_status)}
|
||||
>
|
||||
{invitedUser.tally_status}
|
||||
</Badge>
|
||||
<span className="text-xs text-zinc-500">
|
||||
{getTallySummary(invitedUser)}
|
||||
</span>
|
||||
<span className="text-xs text-zinc-400">
|
||||
{formatDate(invitedUser.tally_computed_at ?? undefined)}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="font-mono text-xs text-zinc-500">
|
||||
{invitedUser.auth_user_id || "-"}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm text-zinc-500">
|
||||
{formatDate(invitedUser.created_at)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={invitedUser.status === "REVOKED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"retry",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRetryTally(invitedUser.id)}
|
||||
>
|
||||
Retry Tally
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
disabled={invitedUser.status !== "INVITED"}
|
||||
loading={isActionPending(
|
||||
pendingInviteAction,
|
||||
"revoke",
|
||||
invitedUser.id,
|
||||
)}
|
||||
onClick={() => onRevoke(invitedUser.id)}
|
||||
>
|
||||
Revoke
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
|
||||
import React from "react";
|
||||
|
||||
function AdminUsers() {
|
||||
return <AdminUsersPage />;
|
||||
return (
|
||||
<div>
|
||||
<h1>Users Dashboard</h1>
|
||||
{/* Add your admin-only content here */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function AdminUsersRoute() {
|
||||
export default async function AdminUsersPage() {
|
||||
"use server";
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import {
|
||||
getGetV2ListInvitedUsersQueryKey,
|
||||
useGetV2ListInvitedUsers,
|
||||
usePostV2BulkCreateInvitedUsers,
|
||||
usePostV2CreateInvitedUser,
|
||||
usePostV2RetryInvitedUserTally,
|
||||
usePostV2RevokeInvitedUser,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { type FormEvent, useState } from "react";
|
||||
|
||||
function getErrorMessage(error: unknown) {
|
||||
if (error instanceof Error) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return "Something went wrong";
|
||||
}
|
||||
|
||||
export function useAdminUsersPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { toast } = useToast();
|
||||
const [email, setEmail] = useState("");
|
||||
const [name, setName] = useState("");
|
||||
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
|
||||
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
|
||||
const [lastBulkInviteResult, setLastBulkInviteResult] =
|
||||
useState<BulkInvitedUsersResponse | null>(null);
|
||||
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const invitedUsersQuery = useGetV2ListInvitedUsers(undefined, {
|
||||
query: {
|
||||
select: okData,
|
||||
refetchInterval: 30_000,
|
||||
},
|
||||
});
|
||||
|
||||
const createInvitedUserMutation = usePostV2CreateInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setEmail("");
|
||||
setName("");
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invited user created",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
|
||||
mutation: {
|
||||
onSuccess: async (response) => {
|
||||
const result = okData(response) ?? null;
|
||||
setBulkInviteFile(null);
|
||||
setBulkInviteInputKey((currentValue) => currentValue + 1);
|
||||
setLastBulkInviteResult(result);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: result
|
||||
? `${result.created_count} invites created`
|
||||
: "Bulk invite upload complete",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Tally pre-seeding restarted",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
|
||||
mutation: {
|
||||
onSuccess: async () => {
|
||||
setPendingInviteAction(null);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListInvitedUsersQueryKey(),
|
||||
});
|
||||
toast({
|
||||
title: "Invite revoked",
|
||||
variant: "default",
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setPendingInviteAction(null);
|
||||
toast({
|
||||
title: getErrorMessage(error),
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
createInvitedUserMutation.mutate({
|
||||
data: {
|
||||
email,
|
||||
name: name.trim() || null,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRetryTally(invitedUserId: string) {
|
||||
setPendingInviteAction(`retry:${invitedUserId}`);
|
||||
retryInvitedUserTallyMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
function handleBulkInviteFileChange(file: File | null) {
|
||||
setBulkInviteFile(file);
|
||||
}
|
||||
|
||||
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
|
||||
event.preventDefault();
|
||||
|
||||
if (!bulkInviteFile) {
|
||||
return;
|
||||
}
|
||||
|
||||
bulkCreateInvitedUsersMutation.mutate({
|
||||
data: {
|
||||
file: bulkInviteFile,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function handleRevoke(invitedUserId: string) {
|
||||
setPendingInviteAction(`revoke:${invitedUserId}`);
|
||||
revokeInvitedUserMutation.mutate({ invitedUserId });
|
||||
}
|
||||
|
||||
return {
|
||||
email,
|
||||
name,
|
||||
bulkInviteFile,
|
||||
bulkInviteInputKey,
|
||||
lastBulkInviteResult,
|
||||
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
|
||||
invitedUsersError: invitedUsersQuery.error,
|
||||
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
|
||||
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
|
||||
isCreatingInvite: createInvitedUserMutation.isPending,
|
||||
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
|
||||
pendingInviteAction,
|
||||
setEmail,
|
||||
setName,
|
||||
handleBulkInviteFileChange,
|
||||
handleBulkInviteSubmit,
|
||||
handleCreateInvite,
|
||||
handleRetryTally,
|
||||
handleRevoke,
|
||||
};
|
||||
}
|
||||
@@ -1,18 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2GetSuggestedPrompts } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { SpinnerGapIcon } from "@phosphor-icons/react";
|
||||
import { motion } from "framer-motion";
|
||||
import { useEffect, useState } from "react";
|
||||
import {
|
||||
getGreetingName,
|
||||
getInputPlaceholder,
|
||||
getSuggestionThemes,
|
||||
getQuickActions,
|
||||
} from "./helpers";
|
||||
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
|
||||
|
||||
interface Props {
|
||||
inputLayoutId: string;
|
||||
@@ -34,35 +33,25 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
|
||||
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
|
||||
useGetV2GetSuggestedPrompts({
|
||||
query: { staleTime: Infinity },
|
||||
});
|
||||
const themes = getSuggestionThemes(
|
||||
suggestedPromptsResponse?.status === 200
|
||||
? suggestedPromptsResponse.data.themes
|
||||
: undefined,
|
||||
);
|
||||
|
||||
const quickActions = getQuickActions();
|
||||
const [loadingAction, setLoadingAction] = useState<string | null>(null);
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
function update() {
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
}, [window.innerWidth]);
|
||||
|
||||
async function handleQuickActionClick(action: string) {
|
||||
if (isCreatingSession || loadingAction !== null) return;
|
||||
setLoadingAction(action);
|
||||
try {
|
||||
await onSend(action);
|
||||
} finally {
|
||||
setLoadingAction(null);
|
||||
}
|
||||
const mq500 = window.matchMedia("(min-width: 500px)");
|
||||
const mq1081 = window.matchMedia("(min-width: 1081px)");
|
||||
update();
|
||||
mq500.addEventListener("change", update);
|
||||
mq1081.addEventListener("change", update);
|
||||
return () => {
|
||||
mq500.removeEventListener("change", update);
|
||||
mq1081.removeEventListener("change", update);
|
||||
};
|
||||
}, []);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
|
||||
@@ -100,19 +89,30 @@ export function EmptySession({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isLoadingPrompts ? (
|
||||
<div className="flex flex-wrap items-center justify-center gap-3">
|
||||
{Array.from({ length: 4 }, (_, i) => (
|
||||
<Skeleton key={i} className="h-10 w-28 shrink-0 rounded-full" />
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<SuggestionThemes
|
||||
themes={themes}
|
||||
onSend={onSend}
|
||||
disabled={isCreatingSession}
|
||||
/>
|
||||
)}
|
||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => void handleQuickActionClick(action)}
|
||||
disabled={isCreatingSession || loadingAction !== null}
|
||||
aria-busy={loadingAction === action}
|
||||
leftIcon={
|
||||
loadingAction === action ? (
|
||||
<SpinnerGapIcon
|
||||
className="h-4 w-4 animate-spin"
|
||||
weight="bold"
|
||||
/>
|
||||
) : null
|
||||
}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
BookOpenIcon,
|
||||
PaintBrushIcon,
|
||||
LightningIcon,
|
||||
ListChecksIcon,
|
||||
SpinnerGapIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { useState } from "react";
|
||||
import type { SuggestionTheme } from "../../helpers";
|
||||
|
||||
const THEME_ICONS: Record<string, typeof BookOpenIcon> = {
|
||||
Learn: BookOpenIcon,
|
||||
Create: PaintBrushIcon,
|
||||
Automate: LightningIcon,
|
||||
Organize: ListChecksIcon,
|
||||
};
|
||||
|
||||
interface Props {
|
||||
themes: SuggestionTheme[];
|
||||
onSend: (prompt: string) => void | Promise<void>;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function SuggestionThemes({ themes, onSend, disabled }: Props) {
|
||||
const [openTheme, setOpenTheme] = useState<string | null>(null);
|
||||
const [loadingPrompt, setLoadingPrompt] = useState<string | null>(null);
|
||||
|
||||
async function handlePromptClick(theme: string, prompt: string) {
|
||||
if (disabled || loadingPrompt) return;
|
||||
setLoadingPrompt(`${theme}:${prompt}`);
|
||||
try {
|
||||
await onSend(prompt);
|
||||
} finally {
|
||||
setLoadingPrompt(null);
|
||||
setOpenTheme(null);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-wrap items-center justify-center gap-3">
|
||||
{themes.map((theme) => {
|
||||
const Icon = THEME_ICONS[theme.name];
|
||||
return (
|
||||
<Popover
|
||||
key={theme.name}
|
||||
open={openTheme === theme.name}
|
||||
onOpenChange={(open) => setOpenTheme(open ? theme.name : null)}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
disabled={disabled || loadingPrompt !== null}
|
||||
className="shrink-0 gap-2 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
>
|
||||
{Icon && <Icon size={16} weight="regular" />}
|
||||
{theme.name}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="center" className="w-80 p-2">
|
||||
<ul className="grid gap-0.5">
|
||||
{theme.prompts.map((prompt) => (
|
||||
<li key={prompt}>
|
||||
<button
|
||||
type="button"
|
||||
disabled={loadingPrompt !== null}
|
||||
onClick={() => void handlePromptClick(theme.name, prompt)}
|
||||
className="w-full rounded-md px-3 py-2 text-left text-sm text-zinc-700 transition-colors hover:bg-zinc-100 disabled:opacity-50"
|
||||
>
|
||||
{loadingPrompt === `${theme.name}:${prompt}` ? (
|
||||
<span className="flex items-center gap-2">
|
||||
<SpinnerGapIcon
|
||||
className="h-4 w-4 animate-spin"
|
||||
weight="bold"
|
||||
/>
|
||||
{prompt}
|
||||
</span>
|
||||
) : (
|
||||
prompt
|
||||
)}
|
||||
</button>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -12,87 +12,12 @@ export function getInputPlaceholder(width?: number) {
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
export interface SuggestionTheme {
|
||||
name: string;
|
||||
prompts: string[];
|
||||
}
|
||||
|
||||
export const DEFAULT_THEMES: SuggestionTheme[] = [
|
||||
{
|
||||
name: "Learn",
|
||||
prompts: [
|
||||
"What can AutoGPT do for me?",
|
||||
"Show me how agents work",
|
||||
"What integrations are available?",
|
||||
"How do I schedule an agent?",
|
||||
"What are the most popular agents?",
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "Create",
|
||||
prompts: [
|
||||
"Draft a weekly status report",
|
||||
"Generate social media posts for my business",
|
||||
"Create a competitive analysis summary",
|
||||
"Write onboarding emails for new hires",
|
||||
"Build a content calendar for next month",
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "Automate",
|
||||
prompts: [
|
||||
"Monitor relevant websites for changes",
|
||||
"Send me a daily news digest on my industry",
|
||||
"Auto-reply to common customer questions",
|
||||
"Track price changes on products I sell",
|
||||
"Summarize my emails every morning",
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "Organize",
|
||||
prompts: [
|
||||
"Sort my bookmarks into categories",
|
||||
"Create a project timeline from my notes",
|
||||
"Prioritize my task list by urgency",
|
||||
"Build a decision matrix for vendor selection",
|
||||
"Organize my meeting notes into action items",
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export function getSuggestionThemes(
|
||||
apiThemes?: SuggestionTheme[],
|
||||
): SuggestionTheme[] {
|
||||
if (!apiThemes?.length) {
|
||||
return DEFAULT_THEMES;
|
||||
}
|
||||
|
||||
const promptsByTheme = new Map(
|
||||
apiThemes.map((theme) => [theme.name, theme.prompts] as const),
|
||||
);
|
||||
|
||||
// Legacy users have prompts under "General" — distribute them across themes
|
||||
const generalPrompts = (promptsByTheme.get("General") ?? []).filter(
|
||||
(p) => p.trim().length > 0,
|
||||
);
|
||||
|
||||
return DEFAULT_THEMES.map((theme, idx) => {
|
||||
const personalized = (promptsByTheme.get(theme.name) ?? []).filter(
|
||||
(p) => p.trim().length > 0,
|
||||
);
|
||||
|
||||
// Spread legacy "General" prompts round-robin across themes
|
||||
const legacySlice = generalPrompts.filter(
|
||||
(_, i) => i % DEFAULT_THEMES.length === idx,
|
||||
);
|
||||
|
||||
return {
|
||||
name: theme.name,
|
||||
prompts: Array.from(
|
||||
new Set([...personalized, ...legacySlice, ...theme.prompts]),
|
||||
).slice(0, theme.prompts.length),
|
||||
};
|
||||
});
|
||||
export function getQuickActions() {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null) {
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"use server";
|
||||
|
||||
import {
|
||||
postV1CheckIfAnEmailIsAllowedToSignUp,
|
||||
postV1GetOrCreateUser,
|
||||
} from "@/app/api/__generated__/endpoints/auth/auth";
|
||||
import { postV1GetOrCreateUser } from "@/app/api/__generated__/endpoints/auth/auth";
|
||||
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
|
||||
import { ApiError } from "@/lib/autogpt-server-api/helpers";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
@@ -32,31 +28,6 @@ export async function signup(
|
||||
};
|
||||
}
|
||||
|
||||
// Pre-check invite eligibility before creating a Supabase auth user.
|
||||
// This prevents orphaned auth accounts when the invite gate is enabled.
|
||||
try {
|
||||
const checkResult = await resolveResponse(
|
||||
postV1CheckIfAnEmailIsAllowedToSignUp({ email: parsed.data.email }),
|
||||
);
|
||||
if (!checkResult.allowed) {
|
||||
return { success: false, error: "not_allowed" };
|
||||
}
|
||||
// If the check fails (non-OK or backend unreachable), fall through to
|
||||
// signup — the backend-level check in get_or_activate_user() catches it.
|
||||
} catch (precheckError) {
|
||||
if (precheckError instanceof ApiError) {
|
||||
Sentry.captureMessage(
|
||||
`Invite pre-check returned HTTP ${precheckError.status}`,
|
||||
{ level: "warning", tags: { flow: "signup_precheck" } },
|
||||
);
|
||||
} else {
|
||||
Sentry.captureException(precheckError, {
|
||||
tags: { flow: "signup_precheck" },
|
||||
});
|
||||
}
|
||||
// Graceful fallback: don't block signup if the pre-check itself fails.
|
||||
}
|
||||
|
||||
const supabase = await getServerSupabase();
|
||||
if (!supabase) {
|
||||
return {
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
/**
|
||||
* Generated by orval v7.13.0 🍺
|
||||
* Do not edit manually.
|
||||
* AutoGPT Agent Server
|
||||
* This server is used to execute agents that are created by the AutoGPT system.
|
||||
* OpenAPI spec version: 0.1
|
||||
*/
|
||||
import type { SuggestedTheme } from "./suggestedTheme";
|
||||
|
||||
/**
|
||||
* Response model for user-specific suggested prompts grouped by theme.
|
||||
*/
|
||||
export interface SuggestedPromptsResponse {
|
||||
themes: SuggestedTheme[];
|
||||
}
|
||||
@@ -299,40 +299,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/auth/check-invite": {
|
||||
"post": {
|
||||
"tags": ["v1", "auth"],
|
||||
"summary": "Check if an email is allowed to sign up",
|
||||
"description": "Check if an email is allowed to sign up (no auth required).\n\nCalled by the frontend before creating a Supabase auth user to prevent\norphaned accounts when the invite gate is enabled.",
|
||||
"operationId": "postV1Check if an email is allowed to sign up",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CheckInviteRequest" }
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CheckInviteResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/auth/user": {
|
||||
"post": {
|
||||
"tags": ["v1", "auth"],
|
||||
@@ -1392,30 +1358,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/suggested-prompts": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Suggested Prompts",
|
||||
"description": "Get LLM-generated suggested prompts grouped by theme.\n\nReturns personalized quick-action prompts based on the user's\nbusiness understanding. Returns empty themes list if no custom\nprompts are available.",
|
||||
"operationId": "getV2GetSuggestedPrompts",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SuggestedPromptsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/usage": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -6726,214 +6668,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "List Invited Users",
|
||||
"operationId": "getV2List invited users",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "page",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 1,
|
||||
"title": "Page"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "page_size",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 200,
|
||||
"minimum": 1,
|
||||
"default": 50,
|
||||
"title": "Page Size"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/InvitedUsersResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Create Invited User",
|
||||
"operationId": "postV2Create invited user",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreateInvitedUserRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/InvitedUserResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users/bulk": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Bulk Create Invited Users",
|
||||
"operationId": "postV2BulkCreateInvitedUsers",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"multipart/form-data": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Body_postV2BulkCreateInvitedUsers"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/BulkInvitedUsersResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users/{invited_user_id}/retry-tally": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Retry Invited User Tally",
|
||||
"operationId": "postV2Retry invited user tally",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "invited_user_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Invited User Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/InvitedUserResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/users/admin/invited-users/{invited_user_id}/revoke": {
|
||||
"post": {
|
||||
"tags": ["v2", "admin", "users", "admin"],
|
||||
"summary": "Revoke Invited User",
|
||||
"operationId": "postV2Revoke invited user",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "invited_user_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Invited User Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/InvitedUserResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/upload": {
|
||||
"post": {
|
||||
"tags": ["workspace"],
|
||||
@@ -8320,14 +8054,6 @@
|
||||
"required": ["store_listing_version_id"],
|
||||
"title": "Body_postV2Add marketplace agent"
|
||||
},
|
||||
"Body_postV2BulkCreateInvitedUsers": {
|
||||
"properties": {
|
||||
"file": { "type": "string", "format": "binary", "title": "File" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["file"],
|
||||
"title": "Body_postV2BulkCreateInvitedUsers"
|
||||
},
|
||||
"Body_postV2Execute_a_preset": {
|
||||
"properties": {
|
||||
"inputs": {
|
||||
@@ -8362,56 +8088,6 @@
|
||||
"required": ["file"],
|
||||
"title": "Body_postWorkspaceUpload file to workspace"
|
||||
},
|
||||
"BulkInvitedUserRowResponse": {
|
||||
"properties": {
|
||||
"row_number": { "type": "integer", "title": "Row Number" },
|
||||
"email": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Email"
|
||||
},
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["CREATED", "SKIPPED", "ERROR"],
|
||||
"title": "Status"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"invited_user": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/InvitedUserResponse" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["row_number", "status", "message"],
|
||||
"title": "BulkInvitedUserRowResponse"
|
||||
},
|
||||
"BulkInvitedUsersResponse": {
|
||||
"properties": {
|
||||
"created_count": { "type": "integer", "title": "Created Count" },
|
||||
"skipped_count": { "type": "integer", "title": "Skipped Count" },
|
||||
"error_count": { "type": "integer", "title": "Error Count" },
|
||||
"results": {
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/BulkInvitedUserRowResponse"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Results"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"created_count",
|
||||
"skipped_count",
|
||||
"error_count",
|
||||
"results"
|
||||
],
|
||||
"title": "BulkInvitedUsersResponse"
|
||||
},
|
||||
"BulkMoveAgentsRequest": {
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
@@ -8475,20 +8151,6 @@
|
||||
"required": ["query", "conversation_history", "message_id"],
|
||||
"title": "ChatRequest"
|
||||
},
|
||||
"CheckInviteRequest": {
|
||||
"properties": {
|
||||
"email": { "type": "string", "format": "email", "title": "Email" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["email"],
|
||||
"title": "CheckInviteRequest"
|
||||
},
|
||||
"CheckInviteResponse": {
|
||||
"properties": { "allowed": { "type": "boolean", "title": "Allowed" } },
|
||||
"type": "object",
|
||||
"required": ["allowed"],
|
||||
"title": "CheckInviteResponse"
|
||||
},
|
||||
"ClarificationNeededResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
@@ -8612,18 +8274,6 @@
|
||||
"required": ["graph"],
|
||||
"title": "CreateGraph"
|
||||
},
|
||||
"CreateInvitedUserRequest": {
|
||||
"properties": {
|
||||
"email": { "type": "string", "format": "email", "title": "Email" },
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["email"],
|
||||
"title": "CreateInvitedUserRequest"
|
||||
},
|
||||
"CreateSessionResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -10088,80 +9738,6 @@
|
||||
"title": "InputValidationErrorResponse",
|
||||
"description": "Response when run_agent receives unknown input fields."
|
||||
},
|
||||
"InvitedUserResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"email": { "type": "string", "title": "Email" },
|
||||
"status": { "$ref": "#/components/schemas/InvitedUserStatus" },
|
||||
"auth_user_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Auth User Id"
|
||||
},
|
||||
"name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Name"
|
||||
},
|
||||
"tally_understanding": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Tally Understanding"
|
||||
},
|
||||
"tally_status": {
|
||||
"$ref": "#/components/schemas/TallyComputationStatus"
|
||||
},
|
||||
"tally_computed_at": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "format": "date-time" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Tally Computed At"
|
||||
},
|
||||
"tally_error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tally Error"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Created At"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Updated At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"email",
|
||||
"status",
|
||||
"tally_status",
|
||||
"created_at",
|
||||
"updated_at"
|
||||
],
|
||||
"title": "InvitedUserResponse"
|
||||
},
|
||||
"InvitedUserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["INVITED", "CLAIMED", "REVOKED"],
|
||||
"title": "InvitedUserStatus"
|
||||
},
|
||||
"InvitedUsersResponse": {
|
||||
"properties": {
|
||||
"invited_users": {
|
||||
"items": { "$ref": "#/components/schemas/InvitedUserResponse" },
|
||||
"type": "array",
|
||||
"title": "Invited Users"
|
||||
},
|
||||
"pagination": { "$ref": "#/components/schemas/Pagination" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["invited_users", "pagination"],
|
||||
"title": "InvitedUsersResponse"
|
||||
},
|
||||
"LibraryAgent": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -13124,33 +12700,6 @@
|
||||
"title": "SuggestedGoalResponse",
|
||||
"description": "Response when the goal needs refinement with a suggested alternative."
|
||||
},
|
||||
"SuggestedPromptsResponse": {
|
||||
"properties": {
|
||||
"themes": {
|
||||
"items": { "$ref": "#/components/schemas/SuggestedTheme" },
|
||||
"type": "array",
|
||||
"title": "Themes"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["themes"],
|
||||
"title": "SuggestedPromptsResponse",
|
||||
"description": "Response model for user-specific suggested prompts grouped by theme."
|
||||
},
|
||||
"SuggestedTheme": {
|
||||
"properties": {
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"prompts": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Prompts"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["name", "prompts"],
|
||||
"title": "SuggestedTheme",
|
||||
"description": "A themed group of suggested prompts."
|
||||
},
|
||||
"SuggestionsResponse": {
|
||||
"properties": {
|
||||
"recent_searches": {
|
||||
@@ -13176,11 +12725,6 @@
|
||||
"required": ["recent_searches", "providers", "top_blocks"],
|
||||
"title": "SuggestionsResponse"
|
||||
},
|
||||
"TallyComputationStatus": {
|
||||
"type": "string",
|
||||
"enum": ["PENDING", "RUNNING", "READY", "FAILED"],
|
||||
"title": "TallyComputationStatus"
|
||||
},
|
||||
"TimezoneResponse": {
|
||||
"properties": {
|
||||
"timezone": {
|
||||
|
||||
Reference in New Issue
Block a user