mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
602 lines
19 KiB
Python
602 lines
19 KiB
Python
import asyncio
|
|
import csv
|
|
import io
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Literal, Optional
|
|
from uuid import uuid4
|
|
|
|
import prisma.enums
|
|
import prisma.models
|
|
import prisma.types
|
|
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
|
|
|
from backend.data.db import transaction
|
|
from backend.data.model import User
|
|
from backend.data.tally import get_business_understanding_input_from_tally
|
|
from backend.data.understanding import (
|
|
BusinessUnderstandingInput,
|
|
merge_business_understanding_data,
|
|
)
|
|
from backend.util.exceptions import (
|
|
NotAuthorizedError,
|
|
NotFoundError,
|
|
PreconditionFailed,
|
|
)
|
|
from backend.util.json import SafeJson
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_tally_seed_tasks: set[asyncio.Task] = set()
|
|
_email_adapter = TypeAdapter(EmailStr)
|
|
|
|
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
|
MAX_BULK_INVITE_ROWS = 500
|
|
|
|
|
|
class InvitedUserRecord(BaseModel):
|
|
id: str
|
|
email: str
|
|
status: prisma.enums.InvitedUserStatus
|
|
auth_user_id: Optional[str] = None
|
|
name: Optional[str] = None
|
|
tally_understanding: Optional[dict[str, Any]] = None
|
|
tally_status: prisma.enums.TallyComputationStatus
|
|
tally_computed_at: Optional[datetime] = None
|
|
tally_error: Optional[str] = None
|
|
created_at: datetime
|
|
updated_at: datetime
|
|
|
|
@classmethod
|
|
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
|
payload = (
|
|
invited_user.tallyUnderstanding
|
|
if isinstance(invited_user.tallyUnderstanding, dict)
|
|
else None
|
|
)
|
|
return cls(
|
|
id=invited_user.id,
|
|
email=invited_user.email,
|
|
status=invited_user.status,
|
|
auth_user_id=invited_user.authUserId,
|
|
name=invited_user.name,
|
|
tally_understanding=payload,
|
|
tally_status=invited_user.tallyStatus,
|
|
tally_computed_at=invited_user.tallyComputedAt,
|
|
tally_error=invited_user.tallyError,
|
|
created_at=invited_user.createdAt,
|
|
updated_at=invited_user.updatedAt,
|
|
)
|
|
|
|
|
|
class BulkInvitedUserRowResult(BaseModel):
|
|
row_number: int
|
|
email: Optional[str] = None
|
|
name: Optional[str] = None
|
|
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
|
message: str
|
|
invited_user: Optional[InvitedUserRecord] = None
|
|
|
|
|
|
class BulkInvitedUsersResult(BaseModel):
|
|
created_count: int
|
|
skipped_count: int
|
|
error_count: int
|
|
results: list[BulkInvitedUserRowResult]
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class _ParsedInviteRow:
|
|
row_number: int
|
|
email: str
|
|
name: Optional[str]
|
|
|
|
|
|
def normalize_email(email: str) -> str:
|
|
return email.strip().lower()
|
|
|
|
|
|
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
|
if name is None:
|
|
return None
|
|
normalized = name.strip()
|
|
return normalized or None
|
|
|
|
|
|
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
|
if preferred_name:
|
|
return preferred_name
|
|
local_part = email.split("@", 1)[0].strip()
|
|
return local_part or "user"
|
|
|
|
|
|
def _sanitize_username_base(email: str) -> str:
|
|
local_part = email.split("@", 1)[0].lower()
|
|
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
|
sanitized = sanitized.strip("-")
|
|
return sanitized[:40] or "user"
|
|
|
|
|
|
async def _generate_unique_profile_username(email: str, tx) -> str:
|
|
base = _sanitize_username_base(email)
|
|
|
|
for attempt in range(10):
|
|
candidate = base if attempt == 0 else f"{base}-{uuid4().hex[:6]}"
|
|
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
|
where={"username": candidate}
|
|
)
|
|
if existing is None:
|
|
return candidate
|
|
|
|
raise RuntimeError(f"Unable to generate unique username for {email}")
|
|
|
|
|
|
async def _ensure_default_profile(
|
|
user_id: str,
|
|
email: str,
|
|
preferred_name: Optional[str],
|
|
tx,
|
|
) -> None:
|
|
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
|
where={"userId": user_id}
|
|
)
|
|
if existing_profile is not None:
|
|
return
|
|
|
|
username = await _generate_unique_profile_username(email, tx)
|
|
await prisma.models.Profile.prisma(tx).create(
|
|
data=prisma.types.ProfileCreateInput(
|
|
userId=user_id,
|
|
name=_default_profile_name(email, preferred_name),
|
|
username=username,
|
|
description="I'm new here",
|
|
links=[],
|
|
avatarUrl="",
|
|
)
|
|
)
|
|
|
|
|
|
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
|
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
|
where={"userId": user_id},
|
|
data={
|
|
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
|
"update": {},
|
|
},
|
|
)
|
|
|
|
|
|
async def _apply_tally_understanding(
|
|
user_id: str,
|
|
invited_user: "prisma.models.InvitedUser",
|
|
tx,
|
|
) -> None:
|
|
if not isinstance(invited_user.tallyUnderstanding, dict):
|
|
return
|
|
|
|
input_data = BusinessUnderstandingInput.model_validate(
|
|
invited_user.tallyUnderstanding
|
|
)
|
|
payload = merge_business_understanding_data({}, input_data)
|
|
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
|
where={"userId": user_id},
|
|
data={
|
|
"create": {"userId": user_id, "data": SafeJson(payload)},
|
|
"update": {"data": SafeJson(payload)},
|
|
},
|
|
)
|
|
|
|
|
|
async def list_invited_users() -> list[InvitedUserRecord]:
|
|
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
|
order={"createdAt": "desc"}
|
|
)
|
|
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
|
|
|
|
|
|
async def create_invited_user(
|
|
email: str, name: Optional[str] = None
|
|
) -> InvitedUserRecord:
|
|
normalized_email = normalize_email(email)
|
|
normalized_name = _normalize_name(name)
|
|
|
|
existing_user = await prisma.models.User.prisma().find_unique(
|
|
where={"email": normalized_email}
|
|
)
|
|
if existing_user is not None:
|
|
raise PreconditionFailed("An active user with this email already exists")
|
|
|
|
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
|
where={"email": normalized_email}
|
|
)
|
|
if existing_invited_user is not None:
|
|
raise PreconditionFailed("An invited user with this email already exists")
|
|
|
|
invited_user = await prisma.models.InvitedUser.prisma().create(
|
|
data={
|
|
"email": normalized_email,
|
|
"name": normalized_name,
|
|
"status": prisma.enums.InvitedUserStatus.INVITED,
|
|
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
|
}
|
|
)
|
|
schedule_invited_user_tally_precompute(invited_user.id)
|
|
return InvitedUserRecord.from_db(invited_user)
|
|
|
|
|
|
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
|
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
|
where={"id": invited_user_id}
|
|
)
|
|
if invited_user is None:
|
|
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
|
|
|
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
|
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
|
|
|
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
|
return InvitedUserRecord.from_db(invited_user)
|
|
|
|
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
|
where={"id": invited_user_id},
|
|
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
|
)
|
|
if revoked_user is None:
|
|
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
|
return InvitedUserRecord.from_db(revoked_user)
|
|
|
|
|
|
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
|
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
|
where={"id": invited_user_id}
|
|
)
|
|
if invited_user is None:
|
|
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
|
|
|
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
|
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
|
|
|
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
|
where={"id": invited_user_id},
|
|
data={
|
|
"tallyUnderstanding": None,
|
|
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
|
"tallyComputedAt": None,
|
|
"tallyError": None,
|
|
},
|
|
)
|
|
if refreshed_user is None:
|
|
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
|
schedule_invited_user_tally_precompute(invited_user_id)
|
|
return InvitedUserRecord.from_db(refreshed_user)
|
|
|
|
|
|
def _decode_bulk_invite_file(content: bytes) -> str:
|
|
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
|
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
|
|
|
try:
|
|
return content.decode("utf-8-sig")
|
|
except UnicodeDecodeError as exc:
|
|
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
|
|
|
|
|
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
|
indexed_rows: list[tuple[int, list[str]]] = []
|
|
|
|
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
|
normalized_row = [cell.strip() for cell in row]
|
|
if any(normalized_row):
|
|
indexed_rows.append((row_number, normalized_row))
|
|
|
|
if not indexed_rows:
|
|
return []
|
|
|
|
header = [cell.lower() for cell in indexed_rows[0][1]]
|
|
has_header = "email" in header
|
|
email_index = header.index("email") if has_header else 0
|
|
name_index = header.index("name") if has_header and "name" in header else 1
|
|
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
|
|
|
parsed_rows: list[_ParsedInviteRow] = []
|
|
for row_number, row in data_rows:
|
|
email = row[email_index].strip() if len(row) > email_index else ""
|
|
name = row[name_index].strip() if len(row) > name_index else ""
|
|
parsed_rows.append(
|
|
_ParsedInviteRow(
|
|
row_number=row_number,
|
|
email=email,
|
|
name=name or None,
|
|
)
|
|
)
|
|
|
|
return parsed_rows
|
|
|
|
|
|
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
|
parsed_rows: list[_ParsedInviteRow] = []
|
|
|
|
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
|
line = raw_line.strip()
|
|
if not line or line.startswith("#"):
|
|
continue
|
|
|
|
parsed_rows.append(
|
|
_ParsedInviteRow(
|
|
row_number=row_number,
|
|
email=line,
|
|
name=None,
|
|
)
|
|
)
|
|
|
|
return parsed_rows
|
|
|
|
|
|
def _parse_bulk_invite_file(
|
|
filename: Optional[str],
|
|
content: bytes,
|
|
) -> list[_ParsedInviteRow]:
|
|
text = _decode_bulk_invite_file(content)
|
|
file_name = filename.lower() if filename else ""
|
|
parsed_rows = (
|
|
_parse_bulk_invite_csv(text)
|
|
if file_name.endswith(".csv")
|
|
else _parse_bulk_invite_text(text)
|
|
)
|
|
|
|
if not parsed_rows:
|
|
raise ValueError("Invite file did not contain any emails")
|
|
|
|
if len(parsed_rows) > MAX_BULK_INVITE_ROWS:
|
|
raise ValueError(
|
|
f"Invite file contains too many rows. Maximum supported rows: {MAX_BULK_INVITE_ROWS}"
|
|
)
|
|
|
|
return parsed_rows
|
|
|
|
|
|
async def bulk_create_invited_users_from_file(
|
|
filename: Optional[str],
|
|
content: bytes,
|
|
) -> BulkInvitedUsersResult:
|
|
parsed_rows = _parse_bulk_invite_file(filename, content)
|
|
|
|
created_count = 0
|
|
skipped_count = 0
|
|
error_count = 0
|
|
results: list[BulkInvitedUserRowResult] = []
|
|
seen_emails: set[str] = set()
|
|
|
|
for row in parsed_rows:
|
|
row_name = _normalize_name(row.name)
|
|
|
|
try:
|
|
validated_email = _email_adapter.validate_python(row.email)
|
|
except ValidationError:
|
|
error_count += 1
|
|
results.append(
|
|
BulkInvitedUserRowResult(
|
|
row_number=row.row_number,
|
|
email=row.email or None,
|
|
name=row_name,
|
|
status="ERROR",
|
|
message="Invalid email address",
|
|
)
|
|
)
|
|
continue
|
|
|
|
normalized_email = normalize_email(str(validated_email))
|
|
if normalized_email in seen_emails:
|
|
skipped_count += 1
|
|
results.append(
|
|
BulkInvitedUserRowResult(
|
|
row_number=row.row_number,
|
|
email=normalized_email,
|
|
name=row_name,
|
|
status="SKIPPED",
|
|
message="Duplicate email in upload file",
|
|
)
|
|
)
|
|
continue
|
|
|
|
seen_emails.add(normalized_email)
|
|
|
|
try:
|
|
invited_user = await create_invited_user(normalized_email, row_name)
|
|
except PreconditionFailed as exc:
|
|
skipped_count += 1
|
|
results.append(
|
|
BulkInvitedUserRowResult(
|
|
row_number=row.row_number,
|
|
email=normalized_email,
|
|
name=row_name,
|
|
status="SKIPPED",
|
|
message=str(exc),
|
|
)
|
|
)
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to create bulk invite for %s from row %s",
|
|
normalized_email,
|
|
row.row_number,
|
|
)
|
|
error_count += 1
|
|
results.append(
|
|
BulkInvitedUserRowResult(
|
|
row_number=row.row_number,
|
|
email=normalized_email,
|
|
name=row_name,
|
|
status="ERROR",
|
|
message="Unexpected error creating invite",
|
|
)
|
|
)
|
|
else:
|
|
created_count += 1
|
|
results.append(
|
|
BulkInvitedUserRowResult(
|
|
row_number=row.row_number,
|
|
email=normalized_email,
|
|
name=row_name,
|
|
status="CREATED",
|
|
message="Invite created",
|
|
invited_user=invited_user,
|
|
)
|
|
)
|
|
|
|
return BulkInvitedUsersResult(
|
|
created_count=created_count,
|
|
skipped_count=skipped_count,
|
|
error_count=error_count,
|
|
results=results,
|
|
)
|
|
|
|
|
|
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
|
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
|
where={"id": invited_user_id}
|
|
)
|
|
if invited_user is None:
|
|
return
|
|
|
|
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
|
return
|
|
|
|
await prisma.models.InvitedUser.prisma().update(
|
|
where={"id": invited_user_id},
|
|
data={
|
|
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
|
"tallyError": None,
|
|
},
|
|
)
|
|
|
|
try:
|
|
input_data = await get_business_understanding_input_from_tally(
|
|
invited_user.email,
|
|
require_api_key=True,
|
|
)
|
|
payload = (
|
|
SafeJson(input_data.model_dump(exclude_none=True))
|
|
if input_data is not None
|
|
else None
|
|
)
|
|
await prisma.models.InvitedUser.prisma().update(
|
|
where={"id": invited_user_id},
|
|
data={
|
|
"tallyUnderstanding": payload,
|
|
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
|
"tallyComputedAt": datetime.now(timezone.utc),
|
|
"tallyError": None,
|
|
},
|
|
)
|
|
except Exception as exc:
|
|
logger.exception(
|
|
"Failed to compute Tally understanding for invited user %s",
|
|
invited_user_id,
|
|
)
|
|
await prisma.models.InvitedUser.prisma().update(
|
|
where={"id": invited_user_id},
|
|
data={
|
|
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
|
"tallyError": str(exc),
|
|
},
|
|
)
|
|
|
|
|
|
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
|
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
|
_tally_seed_tasks.add(task)
|
|
task.add_done_callback(_tally_seed_tasks.discard)
|
|
|
|
|
|
async def get_or_activate_user(user_data: dict) -> User:
|
|
auth_user_id = user_data.get("sub")
|
|
if not auth_user_id:
|
|
raise NotAuthorizedError("User ID not found in token")
|
|
|
|
auth_email = user_data.get("email")
|
|
if not auth_email:
|
|
raise NotAuthorizedError("Email not found in token")
|
|
|
|
normalized_email = normalize_email(auth_email)
|
|
user_metadata = user_data.get("user_metadata")
|
|
metadata_name = (
|
|
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
|
)
|
|
|
|
existing_user = await prisma.models.User.prisma().find_unique(
|
|
where={"id": auth_user_id}
|
|
)
|
|
if existing_user is not None:
|
|
return User.from_db(existing_user)
|
|
|
|
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
|
where={"email": normalized_email}
|
|
)
|
|
if invited_user is None:
|
|
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
|
|
|
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
|
raise NotAuthorizedError("Your invitation is no longer active")
|
|
|
|
async with transaction() as tx:
|
|
current_user = await prisma.models.User.prisma(tx).find_unique(
|
|
where={"id": auth_user_id}
|
|
)
|
|
if current_user is not None:
|
|
return User.from_db(current_user)
|
|
|
|
current_invited_user = await prisma.models.InvitedUser.prisma(tx).find_unique(
|
|
where={"email": normalized_email}
|
|
)
|
|
if current_invited_user is None:
|
|
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
|
|
|
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
|
raise NotAuthorizedError("Your invitation is no longer active")
|
|
|
|
if current_invited_user.authUserId not in (None, auth_user_id):
|
|
raise NotAuthorizedError("Your invitation has already been claimed")
|
|
|
|
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
|
await prisma.models.User.prisma(tx).create(
|
|
data=prisma.types.UserCreateInput(
|
|
id=auth_user_id,
|
|
email=normalized_email,
|
|
name=preferred_name,
|
|
)
|
|
)
|
|
|
|
await prisma.models.InvitedUser.prisma(tx).update(
|
|
where={"id": current_invited_user.id},
|
|
data={
|
|
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
|
"authUserId": auth_user_id,
|
|
},
|
|
)
|
|
|
|
await _ensure_default_profile(
|
|
auth_user_id,
|
|
normalized_email,
|
|
preferred_name,
|
|
tx,
|
|
)
|
|
await _ensure_default_onboarding(auth_user_id, tx)
|
|
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
|
|
|
from backend.data.user import get_user_by_email, get_user_by_id
|
|
|
|
get_user_by_id.cache_delete(auth_user_id)
|
|
get_user_by_email.cache_delete(normalized_email)
|
|
|
|
activated_user = await prisma.models.User.prisma().find_unique(
|
|
where={"id": auth_user_id}
|
|
)
|
|
if activated_user is None:
|
|
raise RuntimeError(
|
|
f"Activated user {auth_user_id} was not found after creation"
|
|
)
|
|
|
|
return User.from_db(activated_user)
|