Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-02-16 17:55:55 -05:00

Compare commits: 3 commits, dev...swiftyos/s

| Author | SHA1 | Date |
|---|---|---|
|  | 9c2a0229d5 |  |
|  | 2a05a39de7 |  |
|  | 92d2da442d |  |
@@ -190,5 +190,8 @@ ZEROBOUNCE_API_KEY=
 POSTHOG_API_KEY=
 POSTHOG_HOST=https://eu.i.posthog.com

+# Tally Form Integration (pre-populate business understanding on signup)
+TALLY_API_KEY=
+
 # Other Services
 AUTOMOD_API_KEY=
@@ -1016,7 +1016,7 @@ async def health_check() -> dict:

     # Ensure health check user exists (required for FK constraint)
     health_check_user_id = "health-check-user"
-    await get_or_create_user(
+    await get_or_create_user(  # returns (User, is_new); we only need the side-effect
         {
             "sub": health_check_user_id,
             "email": "health-check@system.local",
@@ -45,7 +45,7 @@ async def setup_test_data():
         "sub": f"test-user-{uuid.uuid4()}",
         "email": f"test-{uuid.uuid4()}@example.com",
     }
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)

     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]
@@ -168,7 +168,7 @@ async def setup_llm_test_data():
         "sub": f"test-user-{uuid.uuid4()}",
         "email": f"test-{uuid.uuid4()}@example.com",
     }
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)

     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]
@@ -328,7 +328,7 @@ async def setup_firecrawl_test_data():
         "sub": f"test-user-{uuid.uuid4()}",
         "email": f"test-{uuid.uuid4()}@example.com",
     }
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)

     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]

@@ -126,6 +126,9 @@ v1_router = APIRouter()
 ########################################################


+_tally_background_tasks: set[asyncio.Task] = set()
+
+
 @v1_router.post(
     "/auth/user",
     summary="Get or create user",
@@ -133,7 +136,21 @@ v1_router = APIRouter()
     dependencies=[Security(requires_user)],
 )
 async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
-    user = await get_or_create_user(user_data)
+    user, is_new = await get_or_create_user(user_data)
+
+    # Fire-and-forget: populate business understanding from Tally form
+    if is_new:
+        try:
+            from backend.data.tally import populate_understanding_from_tally
+
+            task = asyncio.create_task(
+                populate_understanding_from_tally(user.id, user.email)
+            )
+            _tally_background_tasks.add(task)
+            task.add_done_callback(_tally_background_tasks.discard)
+        except Exception:
+            pass  # Never block user creation
+
     return user.model_dump()

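The module-level `_tally_background_tasks` set added above follows the standard asyncio idiom for fire-and-forget work: the event loop holds only weak references to tasks, so a task with no other strong reference can be garbage-collected before it completes. A self-contained sketch of the idiom (names here are illustrative, not part of this PR):

```python
import asyncio

_background_tasks: set[asyncio.Task] = set()


async def do_work() -> None:
    await asyncio.sleep(0.1)  # stand-in for real async work


async def main() -> None:
    task = asyncio.create_task(do_work())
    _background_tasks.add(task)  # strong reference keeps the task alive
    task.add_done_callback(_background_tasks.discard)  # drop it once finished
    await asyncio.sleep(0.2)  # let the background task run to completion


asyncio.run(main())
```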
@@ -161,7 +178,7 @@ async def get_user_timezone_route(
     user_data: dict = Security(get_jwt_payload),
 ) -> TimezoneResponse:
     """Get user timezone setting."""
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)
     return TimezoneResponse(timezone=user.timezone)

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(

     mocker.patch(
         "backend.api.features.v1.get_or_create_user",
-        return_value=mock_user,
+        return_value=(mock_user, False),
     )

     response = client.post("/auth/user")

autogpt_platform/backend/backend/data/tally.py (new file, 418 lines)
@@ -0,0 +1,418 @@
"""Tally form integration: cache submissions, match by email, extract business understanding."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
get_business_understanding,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TALLY_API_BASE = "https://api.tally.so"
|
||||
TALLY_FORM_ID = "npGe0q"
|
||||
|
||||
# Redis key templates
|
||||
_EMAIL_INDEX_KEY = "tally:form:{form_id}:email_index"
|
||||
_QUESTIONS_KEY = "tally:form:{form_id}:questions"
|
||||
_LAST_FETCH_KEY = "tally:form:{form_id}:last_fetch"
|
||||
|
||||
# TTLs
|
||||
_INDEX_TTL = 3600 # 1 hour
|
||||
_LAST_FETCH_TTL = 7200 # 2 hours
|
||||
|
||||
# Pagination
|
||||
_PAGE_LIMIT = 500
|
||||
_MAX_PAGES = 100
|
||||
|
||||
# LLM extraction timeout (seconds)
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
|
||||
def _mask_email(email: str) -> str:
|
||||
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
|
||||
try:
|
||||
local, domain = email.rsplit("@", 1)
|
||||
if len(local) <= 2:
|
||||
masked_local = local[0] + "***"
|
||||
else:
|
||||
masked_local = local[0] + "***" + local[-1]
|
||||
return f"{masked_local}@{domain}"
|
||||
except (ValueError, IndexError):
|
||||
return "***"
|
||||
|
||||
|
||||
async def _fetch_tally_page(
|
||||
form_id: str,
|
||||
page: int,
|
||||
limit: int = _PAGE_LIMIT,
|
||||
start_date: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Fetch a single page of submissions from the Tally API."""
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.tally_api_key
|
||||
|
||||
url = f"{TALLY_API_BASE}/forms/{form_id}/submissions?page={page}&limit={limit}"
|
||||
if start_date:
|
||||
url += f"&startDate={start_date}"
|
||||
|
||||
client = Requests(
|
||||
trusted_origins=[TALLY_API_BASE],
|
||||
raise_for_status=True,
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
response = await client.get(url)
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _fetch_all_submissions(
|
||||
form_id: str,
|
||||
start_date: Optional[str] = None,
|
||||
max_pages: int = _MAX_PAGES,
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""Paginate through all Tally submissions. Returns (questions, submissions)."""
|
||||
questions: list[dict] = []
|
||||
all_submissions: list[dict] = []
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
data = await _fetch_tally_page(form_id, page, start_date=start_date)
|
||||
|
||||
if page == 1:
|
||||
questions = data.get("questions", [])
|
||||
|
||||
submissions = data.get("submissions", [])
|
||||
all_submissions.extend(submissions)
|
||||
|
||||
total_pages = data.get("totalNumberOfPages", 1)
|
||||
if page >= total_pages:
|
||||
break
|
||||
if page >= max_pages:
|
||||
logger.warning(
|
||||
f"Tally: hit max page cap ({max_pages}) for form {form_id}, "
|
||||
f"API reports {total_pages} total pages"
|
||||
)
|
||||
break
|
||||
page += 1
|
||||
|
||||
return questions, all_submissions
|
||||
|
||||
|
||||
def _build_email_index(
|
||||
submissions: list[dict], questions: list[dict]
|
||||
) -> dict[str, dict]:
|
||||
"""Build an {email -> submission_data} index from submissions.
|
||||
|
||||
Scans question titles for email/contact fields to find the email answer.
|
||||
"""
|
||||
# Find question IDs that are likely email fields
|
||||
email_question_ids: list[str] = []
|
||||
for q in questions:
|
||||
label = (q.get("label") or q.get("title") or q.get("name") or "").lower()
|
||||
q_type = (q.get("type") or "").lower()
|
||||
if q_type in ("input_email", "email"):
|
||||
email_question_ids.append(q["id"])
|
||||
elif any(kw in label for kw in ("email", "e-mail", "contact")):
|
||||
email_question_ids.append(q["id"])
|
||||
|
||||
index: dict[str, dict] = {}
|
||||
for sub in submissions:
|
||||
email = _extract_email_from_submission(sub, email_question_ids)
|
||||
if email:
|
||||
index[email.lower()] = {
|
||||
"responses": sub.get("responses", sub.get("fields", [])),
|
||||
"submitted_at": sub.get("submittedAt", sub.get("createdAt", "")),
|
||||
"questions": sub.get("questions", []),
|
||||
}
|
||||
return index
|
||||
|
||||
|
||||
def _extract_email_from_submission(
|
||||
submission: dict, email_question_ids: list[str]
|
||||
) -> Optional[str]:
|
||||
"""Extract email address from a submission's responses."""
|
||||
# Try respondent email first (Tally often includes this)
|
||||
respondent_email = submission.get("respondentEmail")
|
||||
if respondent_email:
|
||||
return respondent_email
|
||||
|
||||
# Search through responses/fields for matching question IDs
|
||||
responses = submission.get("responses", submission.get("fields", []))
|
||||
if isinstance(responses, list):
|
||||
for resp in responses:
|
||||
q_id = resp.get("questionId") or resp.get("key") or resp.get("id")
|
||||
if q_id in email_question_ids:
|
||||
value = resp.get("value") or resp.get("answer")
|
||||
if isinstance(value, str) and "@" in value:
|
||||
return value
|
||||
elif isinstance(responses, dict):
|
||||
for q_id in email_question_ids:
|
||||
value = responses.get(q_id)
|
||||
if isinstance(value, str) and "@" in value:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _get_cached_index(
|
||||
form_id: str,
|
||||
) -> tuple[Optional[dict], Optional[list]]:
|
||||
"""Check Redis for cached email index and questions. Returns (index, questions) or (None, None)."""
|
||||
redis = await get_redis_async()
|
||||
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
|
||||
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
|
||||
|
||||
raw_index = await redis.get(index_key)
|
||||
raw_questions = await redis.get(questions_key)
|
||||
|
||||
if raw_index and raw_questions:
|
||||
return json.loads(raw_index), json.loads(raw_questions)
|
||||
return None, None
|
||||
|
||||
|
||||
async def _refresh_cache(form_id: str) -> tuple[dict, list]:
|
||||
"""Refresh the Tally submission cache. Uses incremental fetch when possible.
|
||||
|
||||
Returns (email_index, questions).
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
|
||||
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
|
||||
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
|
||||
|
||||
last_fetch = await redis.get(last_fetch_key)
|
||||
|
||||
if last_fetch:
|
||||
# Try to load existing index for incremental merge
|
||||
raw_existing = await redis.get(index_key)
|
||||
|
||||
if raw_existing is None:
|
||||
# Index expired but last_fetch still present — fall back to full fetch
|
||||
logger.info("Tally: last_fetch present but index missing, doing full fetch")
|
||||
questions, submissions = await _fetch_all_submissions(form_id)
|
||||
email_index = _build_email_index(submissions, questions)
|
||||
else:
|
||||
# Incremental fetch: only get new submissions since last fetch
|
||||
logger.info(f"Tally incremental fetch since {last_fetch}")
|
||||
questions, new_submissions = await _fetch_all_submissions(
|
||||
form_id, start_date=last_fetch
|
||||
)
|
||||
|
||||
existing_index: dict[str, dict] = json.loads(raw_existing)
|
||||
|
||||
if not questions:
|
||||
raw_q = await redis.get(questions_key)
|
||||
if raw_q:
|
||||
questions = json.loads(raw_q)
|
||||
|
||||
new_index = _build_email_index(new_submissions, questions)
|
||||
existing_index.update(new_index)
|
||||
email_index = existing_index
|
||||
else:
|
||||
# Full initial fetch
|
||||
logger.info("Tally full initial fetch")
|
||||
questions, submissions = await _fetch_all_submissions(form_id)
|
||||
email_index = _build_email_index(submissions, questions)
|
||||
|
||||
# Store in Redis
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await redis.setex(index_key, _INDEX_TTL, json.dumps(email_index))
|
||||
await redis.setex(questions_key, _INDEX_TTL, json.dumps(questions))
|
||||
await redis.setex(last_fetch_key, _LAST_FETCH_TTL, now)
|
||||
|
||||
logger.info(f"Tally cache refreshed: {len(email_index)} emails indexed")
|
||||
return email_index, questions
|
||||
|
||||
|
||||
async def find_submission_by_email(
|
||||
form_id: str, email: str
|
||||
) -> Optional[tuple[dict, list]]:
|
||||
"""Look up a Tally submission by email. Uses cache when available.
|
||||
|
||||
Returns (submission_data, questions) or None.
|
||||
"""
|
||||
email_lower = email.lower()
|
||||
|
||||
# Try cache first
|
||||
email_index, questions = await _get_cached_index(form_id)
|
||||
if email_index is not None and questions is not None:
|
||||
sub = email_index.get(email_lower)
|
||||
if sub is not None:
|
||||
return sub, questions
|
||||
return None
|
||||
|
||||
# Cache miss - refresh
|
||||
email_index, questions = await _refresh_cache(form_id)
|
||||
sub = email_index.get(email_lower)
|
||||
if sub is not None:
|
||||
return sub, questions
|
||||
return None
|
||||
|
||||
|
||||
def format_submission_for_llm(submission: dict, questions: list[dict]) -> str:
|
||||
"""Format a submission as readable Q&A text for LLM consumption."""
|
||||
# Build question ID -> title lookup
|
||||
q_titles: dict[str, str] = {}
|
||||
for q in questions:
|
||||
q_id = q.get("id", "")
|
||||
title = q.get("label") or q.get("title") or q.get("name") or f"Question {q_id}"
|
||||
q_titles[q_id] = title
|
||||
|
||||
lines: list[str] = []
|
||||
responses = submission.get("responses", [])
|
||||
|
||||
if isinstance(responses, list):
|
||||
for resp in responses:
|
||||
q_id = resp.get("questionId") or resp.get("key") or resp.get("id") or ""
|
||||
title = q_titles.get(q_id, f"Question {q_id}")
|
||||
value = resp.get("value") or resp.get("answer") or ""
|
||||
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
|
||||
elif isinstance(responses, dict):
|
||||
for q_id, value in responses.items():
|
||||
title = q_titles.get(q_id, f"Question {q_id}")
|
||||
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
def _format_answer(value: object) -> str:
|
||||
"""Format an answer value for display."""
|
||||
if value is None:
|
||||
return "(no answer)"
|
||||
if isinstance(value, list):
|
||||
return ", ".join(str(v) for v in value)
|
||||
if isinstance(value, dict):
|
||||
parts = [f"{k}: {v}" for k, v in value.items() if v]
|
||||
return "; ".join(parts) if parts else "(no answer)"
|
||||
return str(value)
|
||||
|
||||
|
||||
_EXTRACTION_PROMPT = """\
|
||||
You are a business analyst. Given the following form submission data, extract structured business understanding information.
|
||||
|
||||
Return a JSON object with ONLY the fields that can be confidently extracted. Use null for fields that cannot be determined.
|
||||
|
||||
Fields:
|
||||
- user_name (string): the person's name
|
||||
- job_title (string): their job title
|
||||
- business_name (string): company/business name
|
||||
- industry (string): industry or sector
|
||||
- business_size (string): company size e.g. "1-10", "11-50", "51-200"
|
||||
- user_role (string): their role context e.g. "decision maker", "implementer"
|
||||
- key_workflows (list of strings): key business workflows
|
||||
- daily_activities (list of strings): daily activities performed
|
||||
- pain_points (list of strings): current pain points
|
||||
- bottlenecks (list of strings): process bottlenecks
|
||||
- manual_tasks (list of strings): manual/repetitive tasks
|
||||
- automation_goals (list of strings): desired automation goals
|
||||
- current_software (list of strings): software/tools currently used
|
||||
- existing_automation (list of strings): existing automations
|
||||
- additional_notes (string): any additional context
|
||||
|
||||
Form data:
|
||||
{submission_text}
|
||||
|
||||
Return ONLY valid JSON."""
|
||||
|
||||
|
||||
async def extract_business_understanding(
|
||||
formatted_text: str,
|
||||
) -> BusinessUnderstandingInput:
|
||||
"""Use an LLM to extract structured business understanding from form text.
|
||||
|
||||
Raises on timeout or unparseable response so the caller can handle it.
|
||||
"""
|
||||
settings = Settings()
|
||||
api_key = settings.secrets.open_router_api_key
|
||||
client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": _EXTRACTION_PROMPT.format(
|
||||
submission_text=formatted_text
|
||||
),
|
||||
}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.0,
|
||||
),
|
||||
timeout=_LLM_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Tally: LLM extraction timed out")
|
||||
raise
|
||||
|
||||
raw = response.choices[0].message.content or "{}"
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Tally: LLM returned invalid JSON, skipping extraction")
|
||||
raise
|
||||
|
||||
# Filter out null values before constructing
|
||||
cleaned = {k: v for k, v in data.items() if v is not None}
|
||||
return BusinessUnderstandingInput(**cleaned)
|
||||
|
||||
|
||||
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
"""Main orchestrator: check Tally for a matching submission and populate understanding.
|
||||
|
||||
Fire-and-forget safe — all exceptions are caught and logged.
|
||||
"""
|
||||
try:
|
||||
# Check if understanding already exists (idempotency)
|
||||
existing = await get_business_understanding(user_id)
|
||||
if existing is not None:
|
||||
logger.debug(
|
||||
f"Tally: user {user_id} already has business understanding, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
# Check API key is configured
|
||||
settings = Settings()
|
||||
if not settings.secrets.tally_api_key:
|
||||
logger.debug("Tally: no API key configured, skipping")
|
||||
return
|
||||
|
||||
# Look up submission by email
|
||||
masked = _mask_email(email)
|
||||
result = await find_submission_by_email(TALLY_FORM_ID, email)
|
||||
if result is None:
|
||||
logger.debug(f"Tally: no submission found for {masked}")
|
||||
return
|
||||
|
||||
submission, questions = result
|
||||
logger.info(f"Tally: found submission for {masked}, extracting understanding")
|
||||
|
||||
# Format and extract
|
||||
formatted = format_submission_for_llm(submission, questions)
|
||||
if not formatted.strip():
|
||||
logger.warning("Tally: formatted submission was empty, skipping")
|
||||
return
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Tally: error populating understanding for user {user_id}")
|
||||
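End to end, `populate_understanding_from_tally` chains the pieces above: idempotency check, API-key gate, cached email lookup (with a full or incremental Tally fetch on miss), LLM extraction, and the final upsert. A minimal manual-invocation sketch, assuming a configured TALLY_API_KEY, an OpenRouter key, and reachable Redis and database (the user id and email are hypothetical):

```python
import asyncio

from backend.data.tally import populate_understanding_from_tally

# Hypothetical identifiers; the function logs and swallows all errors itself.
asyncio.run(populate_understanding_from_tally("user-123", "alice@example.com"))
```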
autogpt_platform/backend/backend/data/tally_test.py (new file, 358 lines)
@@ -0,0 +1,358 @@
"""Tests for backend.data.tally module."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.tally import (
|
||||
_build_email_index,
|
||||
_format_answer,
|
||||
_mask_email,
|
||||
find_submission_by_email,
|
||||
format_submission_for_llm,
|
||||
populate_understanding_from_tally,
|
||||
)
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
SAMPLE_QUESTIONS = [
|
||||
{"id": "q1", "label": "What is your name?", "type": "INPUT_TEXT"},
|
||||
{"id": "q2", "label": "Email address", "type": "INPUT_EMAIL"},
|
||||
{"id": "q3", "label": "Company name", "type": "INPUT_TEXT"},
|
||||
{"id": "q4", "label": "Industry", "type": "INPUT_TEXT"},
|
||||
]
|
||||
|
||||
SAMPLE_SUBMISSIONS = [
|
||||
{
|
||||
"respondentEmail": None,
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Alice Smith"},
|
||||
{"questionId": "q2", "value": "alice@example.com"},
|
||||
{"questionId": "q3", "value": "Acme Corp"},
|
||||
{"questionId": "q4", "value": "Technology"},
|
||||
],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
},
|
||||
{
|
||||
"respondentEmail": "bob@example.com",
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Bob Jones"},
|
||||
{"questionId": "q2", "value": "bob@example.com"},
|
||||
{"questionId": "q3", "value": "Bob's Burgers"},
|
||||
{"questionId": "q4", "value": "Food"},
|
||||
],
|
||||
"submittedAt": "2025-01-16T10:00:00Z",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ── _build_email_index ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_build_email_index():
|
||||
index = _build_email_index(SAMPLE_SUBMISSIONS, SAMPLE_QUESTIONS)
|
||||
assert "alice@example.com" in index
|
||||
assert "bob@example.com" in index
|
||||
assert len(index) == 2
|
||||
|
||||
|
||||
def test_build_email_index_case_insensitive():
|
||||
submissions = [
|
||||
{
|
||||
"respondentEmail": None,
|
||||
"responses": [
|
||||
{"questionId": "q2", "value": "Alice@Example.COM"},
|
||||
],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
},
|
||||
]
|
||||
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
|
||||
assert "alice@example.com" in index
|
||||
assert "Alice@Example.COM" not in index
|
||||
|
||||
|
||||
def test_build_email_index_empty():
|
||||
index = _build_email_index([], SAMPLE_QUESTIONS)
|
||||
assert index == {}
|
||||
|
||||
|
||||
def test_build_email_index_no_email_field():
|
||||
questions = [{"id": "q1", "label": "Name", "type": "INPUT_TEXT"}]
|
||||
submissions = [
|
||||
{
|
||||
"responses": [{"questionId": "q1", "value": "Alice"}],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
}
|
||||
]
|
||||
index = _build_email_index(submissions, questions)
|
||||
assert index == {}
|
||||
|
||||
|
||||
def test_build_email_index_respondent_email():
|
||||
"""respondentEmail takes precedence over field scanning."""
|
||||
submissions = [
|
||||
{
|
||||
"respondentEmail": "direct@example.com",
|
||||
"responses": [
|
||||
{"questionId": "q2", "value": "field@example.com"},
|
||||
],
|
||||
"submittedAt": "2025-01-15T10:00:00Z",
|
||||
}
|
||||
]
|
||||
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
|
||||
assert "direct@example.com" in index
|
||||
assert "field@example.com" not in index
|
||||
|
||||
|
||||
# ── format_submission_for_llm ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_submission_for_llm():
|
||||
submission = {
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Alice Smith"},
|
||||
{"questionId": "q3", "value": "Acme Corp"},
|
||||
],
|
||||
}
|
||||
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
|
||||
assert "Q: What is your name?" in result
|
||||
assert "A: Alice Smith" in result
|
||||
assert "Q: Company name" in result
|
||||
assert "A: Acme Corp" in result
|
||||
|
||||
|
||||
def test_format_submission_for_llm_dict_responses():
|
||||
submission = {
|
||||
"responses": {
|
||||
"q1": "Alice Smith",
|
||||
"q3": "Acme Corp",
|
||||
},
|
||||
}
|
||||
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
|
||||
assert "A: Alice Smith" in result
|
||||
assert "A: Acme Corp" in result
|
||||
|
||||
|
||||
def test_format_answer_types():
|
||||
assert _format_answer(None) == "(no answer)"
|
||||
assert _format_answer("hello") == "hello"
|
||||
assert _format_answer(["a", "b"]) == "a, b"
|
||||
assert _format_answer({"key": "val"}) == "key: val"
|
||||
assert _format_answer(42) == "42"
|
||||
|
||||
|
||||
# ── find_submission_by_email ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_submission_by_email_cache_hit():
|
||||
cached_index = {
|
||||
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
|
||||
}
|
||||
cached_questions = SAMPLE_QUESTIONS
|
||||
|
||||
with patch(
|
||||
"backend.data.tally._get_cached_index",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(cached_index, cached_questions),
|
||||
) as mock_cache:
|
||||
result = await find_submission_by_email("form123", "alice@example.com")
|
||||
|
||||
mock_cache.assert_awaited_once_with("form123")
|
||||
assert result is not None
|
||||
sub, questions = result
|
||||
assert sub["submitted_at"] == "2025-01-15"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_submission_by_email_cache_miss():
|
||||
refreshed_index = {
|
||||
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally._get_cached_index",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally._refresh_cache",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(refreshed_index, SAMPLE_QUESTIONS),
|
||||
) as mock_refresh,
|
||||
):
|
||||
result = await find_submission_by_email("form123", "alice@example.com")
|
||||
|
||||
mock_refresh.assert_awaited_once_with("form123")
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_submission_by_email_no_match():
|
||||
cached_index = {
|
||||
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.data.tally._get_cached_index",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(cached_index, SAMPLE_QUESTIONS),
|
||||
):
|
||||
result = await find_submission_by_email("form123", "unknown@example.com")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── populate_understanding_from_tally ─────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_skips_existing():
|
||||
"""If user already has understanding, skip entirely."""
|
||||
mock_understanding = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_understanding,
|
||||
) as mock_get,
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_find,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "test@example.com")
|
||||
|
||||
mock_get.assert_awaited_once_with("user-1")
|
||||
mock_find.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_skips_no_api_key():
|
||||
"""If no Tally API key, skip gracefully."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = ""
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_find,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "test@example.com")
|
||||
|
||||
mock_find.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_handles_errors():
|
||||
"""Must never raise, even on unexpected errors."""
|
||||
with patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("DB down"),
|
||||
):
|
||||
# Should not raise
|
||||
await populate_understanding_from_tally("user-1", "test@example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_full_flow():
|
||||
"""Happy path: no existing understanding, finds submission, extracts, upserts."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
submission = {
|
||||
"responses": [
|
||||
{"questionId": "q1", "value": "Alice"},
|
||||
{"questionId": "q3", "value": "Acme"},
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
) as mock_extract,
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_upsert,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once_with("user-1", mock_input)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_handles_llm_timeout():
|
||||
"""LLM timeout is caught and doesn't raise."""
|
||||
import asyncio
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
submission = {
|
||||
"responses": [{"questionId": "q1", "value": "Alice"}],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_upsert,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_upsert.assert_not_awaited()
|
||||
|
||||
|
||||
# ── _mask_email ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_mask_email():
|
||||
assert _mask_email("alice@example.com") == "a***e@example.com"
|
||||
assert _mask_email("ab@example.com") == "a***@example.com"
|
||||
assert _mask_email("a@example.com") == "a***@example.com"
|
||||
|
||||
|
||||
def test_mask_email_invalid():
|
||||
assert _mask_email("no-at-sign") == "***"
|
||||
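Since every external dependency (Redis, the Tally API, the LLM, the database layer) is mocked, these tests should run without network access. A sketch of invoking just this module programmatically, assuming pytest and pytest-asyncio are installed and the working directory is autogpt_platform/backend:

```python
import sys

import pytest

# Run only the Tally tests; the async cases rely on pytest-asyncio.
sys.exit(pytest.main(["backend/data/tally_test.py", "-v"]))
```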
@@ -29,7 +29,12 @@ cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)


 @cache_user_lookup
-async def get_or_create_user(user_data: dict) -> User:
+async def get_or_create_user(user_data: dict) -> tuple[User, bool]:
+    """Get existing user or create a new one.
+
+    Returns:
+        A tuple of (User, is_new) where is_new is True if the user was just created.
+    """
     try:
         user_id = user_data.get("sub")
         if not user_id:
@@ -39,6 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
         if not user_email:
             raise HTTPException(status_code=401, detail="Email not found in token")

+        is_new = False
         user = await prisma.user.find_unique(where={"id": user_id})
         if not user:
             user = await prisma.user.create(
@@ -48,8 +54,9 @@ async def get_or_create_user(user_data: dict) -> User:
                     name=user_data.get("user_metadata", {}).get("name"),
                 )
             )
+            is_new = True

-        return User.from_db(user)
+        return User.from_db(user), is_new
     except Exception as e:
         raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e

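The return-type change above ripples to every caller: code that previously used the returned User directly now has to unpack the tuple. A condensed sketch of the two caller patterns this diff applies (the import path is assumed from the repo layout; the rest mirrors the call-site updates elsewhere on this page):

```python
from backend.data.user import get_or_create_user  # assumed module path


async def example(user_data: dict) -> None:
    # Callers that don't need the flag discard it:
    user, _ = await get_or_create_user(user_data)

    # The /auth/user route branches on it for one-time signup side effects:
    user, is_new = await get_or_create_user(user_data)
    if is_new:
        ...  # e.g. kick off the Tally background population
```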
@@ -31,7 +31,7 @@ async def create_test_user() -> User:
         "email": "testuser@example.com",
         "name": "Test User",
     }
-    user = await get_or_create_user(test_user_data)
+    user, _ = await get_or_create_user(test_user_data)
     return user

@@ -146,7 +146,7 @@ async def create_test_user() -> User:
         "email": "testuser@example.com",
         "name": "Test User",
     }
-    user = await get_or_create_user(test_user_data)
+    user, _ = await get_or_create_user(test_user_data)
     return user

@@ -21,7 +21,7 @@ async def create_test_user(alt_user: bool = False) -> User:
         "email": "testuser@example.com",
         "name": "Test User",
     }
-    user = await get_or_create_user(test_user_data)
+    user, _ = await get_or_create_user(test_user_data)
     return user

@@ -684,6 +684,11 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):

     screenshotone_api_key: str = Field(default="", description="ScreenshotOne API Key")

+    tally_api_key: str = Field(
+        default="",
+        description="Tally API key for form submission lookup on signup",
+    )
+
     apollo_api_key: str = Field(default="", description="Apollo API Key")
     smartlead_api_key: str = Field(default="", description="SmartLead API Key")
     zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")

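Assuming Secrets follows the usual pydantic BaseSettings convention of mapping field names onto upper-cased environment variables, the TALLY_API_KEY line added to the env file at the top of this diff feeds this field, and the integration gates on it exactly as tally.py does:

```python
from backend.util.settings import Settings

settings = Settings()
if not settings.secrets.tally_api_key:
    print("Tally lookup disabled")  # populate_understanding_from_tally skips silently
```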
@@ -151,7 +151,7 @@ class TestDataCreator:
             }

             # Use the API function to create user in local database
-            user = await get_or_create_user(user_data)
+            user, _ = await get_or_create_user(user_data)
             users.append(user.model_dump())

         except Exception as e: