fix(backend): harden Tally integration and get_or_create_user return type

- get_or_create_user now returns (User, is_new) tuple; Tally background
  task only fires for newly created users
- All callers updated to unpack the new return shape
- extract_business_understanding: add asyncio.wait_for timeout (30s),
  catch TimeoutError and JSONDecodeError
- _refresh_cache: fall back to full fetch when last_fetch exists but
  cached index is missing
- _fetch_all_submissions: add max_pages safety cap (100) to prevent
  infinite pagination loops
- populate_understanding_from_tally: mask emails in all log statements
  via _mask_email helper

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Swifty
2026-02-16 15:00:36 +01:00
parent 92d2da442d
commit 2a05a39de7
11 changed files with 167 additions and 55 deletions

View File

@@ -1016,7 +1016,7 @@ async def health_check() -> dict:
# Ensure health check user exists (required for FK constraint)
health_check_user_id = "health-check-user"
await get_or_create_user(
await get_or_create_user( # returns (User, is_new); we only need the side-effect
{
"sub": health_check_user_id,
"email": "health-check@system.local",

View File

@@ -45,7 +45,7 @@ async def setup_test_data():
"sub": f"test-user-{uuid.uuid4()}",
"email": f"test-{uuid.uuid4()}@example.com",
}
user = await get_or_create_user(user_data)
user, _ = await get_or_create_user(user_data)
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
@@ -168,7 +168,7 @@ async def setup_llm_test_data():
"sub": f"test-user-{uuid.uuid4()}",
"email": f"test-{uuid.uuid4()}@example.com",
}
user = await get_or_create_user(user_data)
user, _ = await get_or_create_user(user_data)
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
@@ -328,7 +328,7 @@ async def setup_firecrawl_test_data():
"sub": f"test-user-{uuid.uuid4()}",
"email": f"test-{uuid.uuid4()}@example.com",
}
user = await get_or_create_user(user_data)
user, _ = await get_or_create_user(user_data)
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]

View File

@@ -136,19 +136,20 @@ _tally_background_tasks: set[asyncio.Task] = set()
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
user, is_new = await get_or_create_user(user_data)
# Fire-and-forget: populate business understanding from Tally form
try:
from backend.data.tally import populate_understanding_from_tally
if is_new:
try:
from backend.data.tally import populate_understanding_from_tally
task = asyncio.create_task(
populate_understanding_from_tally(user.id, user.email)
)
_tally_background_tasks.add(task)
task.add_done_callback(_tally_background_tasks.discard)
except Exception:
pass # Never block user creation
task = asyncio.create_task(
populate_understanding_from_tally(user.id, user.email)
)
_tally_background_tasks.add(task)
task.add_done_callback(_tally_background_tasks.discard)
except Exception:
pass # Never block user creation
return user.model_dump()
@@ -177,7 +178,7 @@ async def get_user_timezone_route(
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
user, _ = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)

View File

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
mocker.patch(
"backend.api.features.v1.get_or_create_user",
return_value=mock_user,
return_value=(mock_user, False),
)
response = client.post("/auth/user")

View File

@@ -1,5 +1,6 @@
"""Tally form integration: cache submissions, match by email, extract business understanding."""
import asyncio
import json
import logging
from datetime import datetime, timezone
@@ -32,6 +33,23 @@ _LAST_FETCH_TTL = 7200 # 2 hours
# Pagination
_PAGE_LIMIT = 500
_MAX_PAGES = 100
# LLM extraction timeout (seconds)
_LLM_TIMEOUT = 30
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
try:
local, domain = email.rsplit("@", 1)
if len(local) <= 2:
masked_local = local[0] + "***"
else:
masked_local = local[0] + "***" + local[-1]
return f"{masked_local}@{domain}"
except (ValueError, IndexError):
return "***"
async def _fetch_tally_page(
@@ -63,6 +81,7 @@ async def _fetch_tally_page(
async def _fetch_all_submissions(
form_id: str,
start_date: Optional[str] = None,
max_pages: int = _MAX_PAGES,
) -> tuple[list[dict], list[dict]]:
"""Paginate through all Tally submissions. Returns (questions, submissions)."""
questions: list[dict] = []
@@ -81,6 +100,12 @@ async def _fetch_all_submissions(
total_pages = data.get("totalNumberOfPages", 1)
if page >= total_pages:
break
if page >= max_pages:
logger.warning(
f"Tally: hit max page cap ({max_pages}) for form {form_id}, "
f"API reports {total_pages} total pages"
)
break
page += 1
return questions, all_submissions
@@ -171,26 +196,31 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
last_fetch = await redis.get(last_fetch_key)
if last_fetch:
# Incremental fetch: only get new submissions since last fetch
logger.info(f"Tally incremental fetch since {last_fetch}")
questions, new_submissions = await _fetch_all_submissions(
form_id, start_date=last_fetch
)
# Try to load existing index to merge
# Try to load existing index for incremental merge
raw_existing = await redis.get(index_key)
existing_index: dict[str, dict] = {}
if raw_existing:
existing_index = json.loads(raw_existing)
if not questions:
raw_q = await redis.get(questions_key)
if raw_q:
questions = json.loads(raw_q)
if raw_existing is None:
# Index expired but last_fetch still present — fall back to full fetch
logger.info("Tally: last_fetch present but index missing, doing full fetch")
questions, submissions = await _fetch_all_submissions(form_id)
email_index = _build_email_index(submissions, questions)
else:
# Incremental fetch: only get new submissions since last fetch
logger.info(f"Tally incremental fetch since {last_fetch}")
questions, new_submissions = await _fetch_all_submissions(
form_id, start_date=last_fetch
)
new_index = _build_email_index(new_submissions, questions)
existing_index.update(new_index)
email_index = existing_index
existing_index: dict[str, dict] = json.loads(raw_existing)
if not questions:
raw_q = await redis.get(questions_key)
if raw_q:
questions = json.loads(raw_q)
new_index = _build_email_index(new_submissions, questions)
existing_index.update(new_index)
email_index = existing_index
else:
# Full initial fetch
logger.info("Tally full initial fetch")
@@ -301,25 +331,41 @@ Return ONLY valid JSON."""
async def extract_business_understanding(
formatted_text: str,
) -> BusinessUnderstandingInput:
"""Use an LLM to extract structured business understanding from form text."""
"""Use an LLM to extract structured business understanding from form text.
Raises on timeout or unparseable response so the caller can handle it.
"""
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
response = await client.chat.completions.create(
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
"content": _EXTRACTION_PROMPT.format(submission_text=formatted_text),
}
],
response_format={"type": "json_object"},
temperature=0.0,
)
try:
response = await asyncio.wait_for(
client.chat.completions.create(
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
"content": _EXTRACTION_PROMPT.format(
submission_text=formatted_text
),
}
],
response_format={"type": "json_object"},
temperature=0.0,
),
timeout=_LLM_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning("Tally: LLM extraction timed out")
raise
raw = response.choices[0].message.content or "{}"
data = json.loads(raw)
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Tally: LLM returned invalid JSON, skipping extraction")
raise
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
@@ -347,13 +393,16 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {email}")
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {email}, extracting understanding")
logger.info(
f"Tally: found submission for {masked}, extracting understanding"
)
# Format and extract
formatted = format_submission_for_llm(submission, questions)
@@ -368,4 +417,6 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
logger.info(f"Tally: successfully populated understanding for user {user_id}")
except Exception:
logger.exception(f"Tally: error populating understanding for user {user_id}")
logger.exception(
f"Tally: error populating understanding for user {user_id}"
)

View File

@@ -7,6 +7,7 @@ import pytest
from backend.data.tally import (
_build_email_index,
_format_answer,
_mask_email,
find_submission_by_email,
format_submission_for_llm,
populate_understanding_from_tally,
@@ -303,3 +304,55 @@ async def test_populate_understanding_full_flow():
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once_with("user-1", mock_input)
@pytest.mark.asyncio
async def test_populate_understanding_handles_llm_timeout():
    """LLM timeout is caught and doesn't raise.

    extract_business_understanding is patched to raise asyncio.TimeoutError;
    populate_understanding_from_tally must swallow it and, because extraction
    never produced a result, must not persist anything via
    upsert_business_understanding.
    """
    import asyncio

    # Settings stub: only the Tally API key attribute is read by the flow.
    mock_settings = MagicMock()
    mock_settings.secrets.tally_api_key = "test-key"

    # Minimal submission payload; content is irrelevant since extraction fails.
    submission = {
        "responses": [{"questionId": "q1", "value": "Alice"}],
    }
    with (
        # No pre-existing understanding, so the flow proceeds past the guard.
        patch(
            "backend.data.tally.get_business_understanding",
            new_callable=AsyncMock,
            return_value=None,
        ),
        patch("backend.data.tally.Settings", return_value=mock_settings),
        # Email lookup succeeds — the failure is injected later, at extraction.
        patch(
            "backend.data.tally.find_submission_by_email",
            new_callable=AsyncMock,
            return_value=(submission, SAMPLE_QUESTIONS),
        ),
        # Simulate the LLM call exceeding its timeout budget.
        patch(
            "backend.data.tally.extract_business_understanding",
            new_callable=AsyncMock,
            side_effect=asyncio.TimeoutError(),
        ),
        patch(
            "backend.data.tally.upsert_business_understanding",
            new_callable=AsyncMock,
        ) as mock_upsert,
    ):
        # Must complete without propagating the TimeoutError.
        await populate_understanding_from_tally("user-1", "alice@example.com")
    # Timeout path must not write any business understanding.
    mock_upsert.assert_not_awaited()
# ── _mask_email ───────────────────────────────────────────────────────────────
def test_mask_email():
    """Well-formed addresses keep first (and, when long enough, last) local chars."""
    cases = {
        "alice@example.com": "a***e@example.com",
        "ab@example.com": "a***@example.com",
        "a@example.com": "a***@example.com",
    }
    for raw, expected in cases.items():
        assert _mask_email(raw) == expected


def test_mask_email_invalid():
    """Input without an '@' separator collapses to the fully-masked placeholder."""
    assert _mask_email("no-at-sign") == "***"

View File

@@ -29,7 +29,12 @@ cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)
@cache_user_lookup
async def get_or_create_user(user_data: dict) -> User:
async def get_or_create_user(user_data: dict) -> tuple[User, bool]:
"""Get existing user or create a new one.
Returns:
A tuple of (User, is_new) where is_new is True if the user was just created.
"""
try:
user_id = user_data.get("sub")
if not user_id:
@@ -39,6 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
if not user_email:
raise HTTPException(status_code=401, detail="Email not found in token")
is_new = False
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
user = await prisma.user.create(
@@ -48,8 +54,9 @@ async def get_or_create_user(user_data: dict) -> User:
name=user_data.get("user_metadata", {}).get("name"),
)
)
is_new = True
return User.from_db(user)
return User.from_db(user), is_new
except Exception as e:
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e

View File

@@ -31,7 +31,7 @@ async def create_test_user() -> User:
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
user, _ = await get_or_create_user(test_user_data)
return user

View File

@@ -146,7 +146,7 @@ async def create_test_user() -> User:
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
user, _ = await get_or_create_user(test_user_data)
return user

View File

@@ -21,7 +21,7 @@ async def create_test_user(alt_user: bool = False) -> User:
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
user, _ = await get_or_create_user(test_user_data)
return user

View File

@@ -151,7 +151,7 @@ class TestDataCreator:
}
# Use the API function to create user in local database
user = await get_or_create_user(user_data)
user, _ = await get_or_create_user(user_data)
users.append(user.model_dump())
except Exception as e: