fix(backend): harden Tally integration and get_or_create_user return type
- get_or_create_user now returns a (User, is_new) tuple; the Tally background task only fires for newly created users
- All callers updated to unpack the new return shape
- extract_business_understanding: add asyncio.wait_for timeout (30s); catch TimeoutError and JSONDecodeError
- _refresh_cache: fall back to a full fetch when last_fetch exists but the cached index is missing
- _fetch_all_submissions: add a max_pages safety cap (100) to prevent infinite pagination loops
- populate_understanding_from_tally: mask emails in all log statements via the _mask_email helper

Co-Authored-By: Claude <noreply@anthropic.com>
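
The first two bullets are the heart of the change. As a minimal standalone sketch of the pattern (the in-memory _records store and the print are hypothetical stand-ins, not project code):

    _records: dict[str, dict] = {}

    def get_or_create(key: str) -> tuple[dict, bool]:
        """Return (record, is_new); is_new is True only on first creation."""
        if key in _records:
            return _records[key], False
        record = {"id": key}
        _records[key] = record
        return record, True

    record, is_new = get_or_create("user-1")
    if is_new:
        print("fire one-time side effect")  # e.g. kick off the Tally background task
    record, is_new = get_or_create("user-1")
    assert is_new is False  # subsequent lookups skip the side effect
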
@@ -1016,7 +1016,7 @@ async def health_check() -> dict:
     # Ensure health check user exists (required for FK constraint)
     health_check_user_id = "health-check-user"
-    await get_or_create_user(
+    await get_or_create_user(  # returns (User, is_new); we only need the side-effect
         {
             "sub": health_check_user_id,
             "email": "health-check@system.local",

@@ -45,7 +45,7 @@ async def setup_test_data():
         "sub": f"test-user-{uuid.uuid4()}",
         "email": f"test-{uuid.uuid4()}@example.com",
     }
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)

     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]

@@ -168,7 +168,7 @@ async def setup_llm_test_data():
         "sub": f"test-user-{uuid.uuid4()}",
         "email": f"test-{uuid.uuid4()}@example.com",
     }
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)

     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]

@@ -328,7 +328,7 @@ async def setup_firecrawl_test_data():
         "sub": f"test-user-{uuid.uuid4()}",
         "email": f"test-{uuid.uuid4()}@example.com",
     }
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)

     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]

@@ -136,19 +136,20 @@ _tally_background_tasks: set[asyncio.Task] = set()
     dependencies=[Security(requires_user)],
 )
 async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
-    user = await get_or_create_user(user_data)
+    user, is_new = await get_or_create_user(user_data)

     # Fire-and-forget: populate business understanding from Tally form
-    try:
-        from backend.data.tally import populate_understanding_from_tally
-
-        task = asyncio.create_task(
-            populate_understanding_from_tally(user.id, user.email)
-        )
-        _tally_background_tasks.add(task)
-        task.add_done_callback(_tally_background_tasks.discard)
-    except Exception:
-        pass  # Never block user creation
+    if is_new:
+        try:
+            from backend.data.tally import populate_understanding_from_tally
+
+            task = asyncio.create_task(
+                populate_understanding_from_tally(user.id, user.email)
+            )
+            _tally_background_tasks.add(task)
+            task.add_done_callback(_tally_background_tasks.discard)
+        except Exception:
+            pass  # Never block user creation

     return user.model_dump()

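Beyond the is_new gate, the _tally_background_tasks set plus add_done_callback(discard) is the standard fire-and-forget bookkeeping: the event loop keeps only weak references to tasks, so an otherwise-unreferenced task can be garbage-collected before it finishes. A minimal runnable sketch of the pattern, with a placeholder coroutine:

    import asyncio

    _background_tasks: set[asyncio.Task] = set()

    async def _work() -> None:
        await asyncio.sleep(0.1)  # placeholder for populate_understanding_from_tally

    async def main() -> None:
        task = asyncio.create_task(_work())
        _background_tasks.add(task)  # strong reference: task can't vanish mid-flight
        task.add_done_callback(_background_tasks.discard)  # release it once done
        await asyncio.sleep(0.2)  # let the background task finish

    asyncio.run(main())
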
@@ -177,7 +178,7 @@ async def get_user_timezone_route(
     user_data: dict = Security(get_jwt_payload),
 ) -> TimezoneResponse:
     """Get user timezone setting."""
-    user = await get_or_create_user(user_data)
+    user, _ = await get_or_create_user(user_data)
     return TimezoneResponse(timezone=user.timezone)

@@ -51,7 +51,7 @@ def test_get_or_create_user_route(

     mocker.patch(
         "backend.api.features.v1.get_or_create_user",
-        return_value=mock_user,
+        return_value=(mock_user, False),
     )

     response = client.post("/auth/user")

@@ -1,5 +1,6 @@
 """Tally form integration: cache submissions, match by email, extract business understanding."""

+import asyncio
 import json
 import logging
 from datetime import datetime, timezone

@@ -32,6 +33,23 @@ _LAST_FETCH_TTL = 7200  # 2 hours

 # Pagination
 _PAGE_LIMIT = 500
+_MAX_PAGES = 100
+
+# LLM extraction timeout (seconds)
+_LLM_TIMEOUT = 30
+
+
+def _mask_email(email: str) -> str:
+    """Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
+    try:
+        local, domain = email.rsplit("@", 1)
+        if len(local) <= 2:
+            masked_local = local[0] + "***"
+        else:
+            masked_local = local[0] + "***" + local[-1]
+        return f"{masked_local}@{domain}"
+    except (ValueError, IndexError):
+        return "***"


 async def _fetch_tally_page(

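For reference, the helper's outputs (these mirror the unit tests added at the bottom of this diff):

    _mask_email("alice@example.com")  # -> "a***e@example.com"
    _mask_email("ab@example.com")     # -> "a***@example.com"
    _mask_email("no-at-sign")         # -> "***"
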
@@ -63,6 +81,7 @@ async def _fetch_tally_page(
 async def _fetch_all_submissions(
     form_id: str,
     start_date: Optional[str] = None,
+    max_pages: int = _MAX_PAGES,
 ) -> tuple[list[dict], list[dict]]:
     """Paginate through all Tally submissions. Returns (questions, submissions)."""
     questions: list[dict] = []
@@ -81,6 +100,12 @@ async def _fetch_all_submissions(
         total_pages = data.get("totalNumberOfPages", 1)
         if page >= total_pages:
             break
+        if page >= max_pages:
+            logger.warning(
+                f"Tally: hit max page cap ({max_pages}) for form {form_id}, "
+                f"API reports {total_pages} total pages"
+            )
+            break
         page += 1

     return questions, all_submissions

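The cap is cheap insurance against an API whose reported page count never converges. A standalone sketch of the bounded loop, with a stub that always claims one more page (names here are illustrative, not Tally's API):

    def fetch_page(page: int) -> dict:
        # Stub: pretend the API always reports one more page than fetched.
        return {"items": [f"item-{page}"], "totalNumberOfPages": page + 1}

    def fetch_all(max_pages: int = 100) -> list[str]:
        items: list[str] = []
        page = 1
        while True:
            data = fetch_page(page)
            items.extend(data["items"])
            if page >= data.get("totalNumberOfPages", 1):
                break
            if page >= max_pages:  # safety cap: a buggy API can't loop us forever
                break
            page += 1
        return items

    assert len(fetch_all(max_pages=5)) == 5  # capped despite endless "more pages"
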
@@ -171,26 +196,31 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
     last_fetch = await redis.get(last_fetch_key)

     if last_fetch:
-        # Incremental fetch: only get new submissions since last fetch
-        logger.info(f"Tally incremental fetch since {last_fetch}")
-        questions, new_submissions = await _fetch_all_submissions(
-            form_id, start_date=last_fetch
-        )
-
-        # Try to load existing index to merge
+        # Try to load existing index for incremental merge
         raw_existing = await redis.get(index_key)
-        existing_index: dict[str, dict] = {}
-        if raw_existing:
-            existing_index = json.loads(raw_existing)
-
-        if not questions:
-            raw_q = await redis.get(questions_key)
-            if raw_q:
-                questions = json.loads(raw_q)
-
-        new_index = _build_email_index(new_submissions, questions)
-        existing_index.update(new_index)
-        email_index = existing_index
+        if raw_existing is None:
+            # Index expired but last_fetch still present — fall back to full fetch
+            logger.info("Tally: last_fetch present but index missing, doing full fetch")
+            questions, submissions = await _fetch_all_submissions(form_id)
+            email_index = _build_email_index(submissions, questions)
+        else:
+            # Incremental fetch: only get new submissions since last fetch
+            logger.info(f"Tally incremental fetch since {last_fetch}")
+            questions, new_submissions = await _fetch_all_submissions(
+                form_id, start_date=last_fetch
+            )
+
+            existing_index: dict[str, dict] = json.loads(raw_existing)
+
+            if not questions:
+                raw_q = await redis.get(questions_key)
+                if raw_q:
+                    questions = json.loads(raw_q)
+
+            new_index = _build_email_index(new_submissions, questions)
+            existing_index.update(new_index)
+            email_index = existing_index
     else:
         # Full initial fetch
         logger.info("Tally full initial fetch")

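The bug this fixes is TTL skew: last_fetch (2h TTL) and the cached index expire independently, so "last_fetch present, index missing" is a reachable state, and the old code would then build an index containing only the new submissions. The decision logic reduced to its shape, with a plain dict standing in for Redis:

    cache: dict[str, str] = {"last_fetch": "2024-01-01T00:00:00Z"}  # index expired

    def refresh_strategy() -> str:
        if not cache.get("last_fetch"):
            return "full initial fetch"
        if cache.get("index") is None:
            return "full fetch"  # index expired out from under us; rebuild it all
        return "incremental merge"  # fetch since last_fetch, merge into the index

    assert refresh_strategy() == "full fetch"
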
@@ -301,25 +331,41 @@ Return ONLY valid JSON."""
 async def extract_business_understanding(
     formatted_text: str,
 ) -> BusinessUnderstandingInput:
-    """Use an LLM to extract structured business understanding from form text."""
+    """Use an LLM to extract structured business understanding from form text.
+
+    Raises on timeout or unparseable response so the caller can handle it.
+    """
     settings = Settings()
     api_key = settings.secrets.open_router_api_key
     client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")

-    response = await client.chat.completions.create(
-        model="openai/gpt-4o-mini",
-        messages=[
-            {
-                "role": "user",
-                "content": _EXTRACTION_PROMPT.format(submission_text=formatted_text),
-            }
-        ],
-        response_format={"type": "json_object"},
-        temperature=0.0,
-    )
+    try:
+        response = await asyncio.wait_for(
+            client.chat.completions.create(
+                model="openai/gpt-4o-mini",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": _EXTRACTION_PROMPT.format(
+                            submission_text=formatted_text
+                        ),
+                    }
+                ],
+                response_format={"type": "json_object"},
+                temperature=0.0,
+            ),
+            timeout=_LLM_TIMEOUT,
+        )
+    except asyncio.TimeoutError:
+        logger.warning("Tally: LLM extraction timed out")
+        raise

     raw = response.choices[0].message.content or "{}"
-    data = json.loads(raw)
+    try:
+        data = json.loads(raw)
+    except json.JSONDecodeError:
+        logger.warning("Tally: LLM returned invalid JSON, skipping extraction")
+        raise

     # Filter out null values before constructing
     cleaned = {k: v for k, v in data.items() if v is not None}

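asyncio.wait_for cancels the wrapped awaitable when the deadline passes and raises asyncio.TimeoutError to the caller. A self-contained sketch of the same timeout-then-parse structure (the sleeping coroutine stands in for the OpenRouter call):

    import asyncio
    import json

    async def slow_llm_call() -> str:
        await asyncio.sleep(60)  # stands in for client.chat.completions.create(...)
        return "{}"

    async def extract() -> dict:
        try:
            raw = await asyncio.wait_for(slow_llm_call(), timeout=0.1)
        except asyncio.TimeoutError:
            print("timed out")  # logged, then re-raised for the caller to handle
            raise
        return json.loads(raw)

    try:
        asyncio.run(extract())
    except asyncio.TimeoutError:
        pass  # the caller (populate_understanding_from_tally) catches and logs
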
@@ -347,13 +393,16 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
             return

         # Look up submission by email
+        masked = _mask_email(email)
         result = await find_submission_by_email(TALLY_FORM_ID, email)
         if result is None:
-            logger.debug(f"Tally: no submission found for {email}")
+            logger.debug(f"Tally: no submission found for {masked}")
             return

         submission, questions = result
-        logger.info(f"Tally: found submission for {email}, extracting understanding")
+        logger.info(
+            f"Tally: found submission for {masked}, extracting understanding"
+        )

         # Format and extract
         formatted = format_submission_for_llm(submission, questions)
@@ -368,4 +417,6 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
         logger.info(f"Tally: successfully populated understanding for user {user_id}")

     except Exception:
-        logger.exception(f"Tally: error populating understanding for user {user_id}")
+        logger.exception(
+            f"Tally: error populating understanding for user {user_id}"
+        )

@@ -7,6 +7,7 @@ import pytest
 from backend.data.tally import (
     _build_email_index,
     _format_answer,
+    _mask_email,
     find_submission_by_email,
     format_submission_for_llm,
     populate_understanding_from_tally,

@@ -303,3 +304,55 @@ async def test_populate_understanding_full_flow():

     mock_extract.assert_awaited_once()
     mock_upsert.assert_awaited_once_with("user-1", mock_input)
+
+
+@pytest.mark.asyncio
+async def test_populate_understanding_handles_llm_timeout():
+    """LLM timeout is caught and doesn't raise."""
+    import asyncio
+
+    mock_settings = MagicMock()
+    mock_settings.secrets.tally_api_key = "test-key"
+
+    submission = {
+        "responses": [{"questionId": "q1", "value": "Alice"}],
+    }
+
+    with (
+        patch(
+            "backend.data.tally.get_business_understanding",
+            new_callable=AsyncMock,
+            return_value=None,
+        ),
+        patch("backend.data.tally.Settings", return_value=mock_settings),
+        patch(
+            "backend.data.tally.find_submission_by_email",
+            new_callable=AsyncMock,
+            return_value=(submission, SAMPLE_QUESTIONS),
+        ),
+        patch(
+            "backend.data.tally.extract_business_understanding",
+            new_callable=AsyncMock,
+            side_effect=asyncio.TimeoutError(),
+        ),
+        patch(
+            "backend.data.tally.upsert_business_understanding",
+            new_callable=AsyncMock,
+        ) as mock_upsert,
+    ):
+        await populate_understanding_from_tally("user-1", "alice@example.com")
+
+        mock_upsert.assert_not_awaited()
+
+
+# ── _mask_email ───────────────────────────────────────────────────────────────
+
+
+def test_mask_email():
+    assert _mask_email("alice@example.com") == "a***e@example.com"
+    assert _mask_email("ab@example.com") == "a***@example.com"
+    assert _mask_email("a@example.com") == "a***@example.com"
+
+
+def test_mask_email_invalid():
+    assert _mask_email("no-at-sign") == "***"

@@ -29,7 +29,12 @@ cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)


 @cache_user_lookup
-async def get_or_create_user(user_data: dict) -> User:
+async def get_or_create_user(user_data: dict) -> tuple[User, bool]:
+    """Get existing user or create a new one.
+
+    Returns:
+        A tuple of (User, is_new) where is_new is True if the user was just created.
+    """
     try:
         user_id = user_data.get("sub")
         if not user_id:
@@ -39,6 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
         if not user_email:
             raise HTTPException(status_code=401, detail="Email not found in token")

+        is_new = False
         user = await prisma.user.find_unique(where={"id": user_id})
         if not user:
             user = await prisma.user.create(
@@ -48,8 +54,9 @@ async def get_or_create_user(user_data: dict) -> User:
                     name=user_data.get("user_metadata", {}).get("name"),
                 )
             )
+            is_new = True

-        return User.from_db(user)
+        return User.from_db(user), is_new
     except Exception as e:
         raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e

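One subtlety worth noting: cache_user_lookup memoizes the whole return value, so depending on how the cached helper keys its entries, a repeat call within the 300s TTL can replay (user, True) from cache. A toy illustration of how a result cache interacts with a freshness flag (functools.lru_cache as a stand-in for the project's cached decorator):

    import functools

    @functools.lru_cache(maxsize=None)  # stand-in for cache_user_lookup
    def get_or_create(key: str) -> tuple[str, bool]:
        return key, True  # the "just created" flag is cached along with the value

    assert get_or_create("u1") == ("u1", True)
    assert get_or_create("u1") == ("u1", True)  # cache hit replays is_new=True
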
@@ -31,7 +31,7 @@ async def create_test_user() -> User:
         "email": "testuser@example.com",
         "name": "Test User",
     }
-    user = await get_or_create_user(test_user_data)
+    user, _ = await get_or_create_user(test_user_data)
     return user

@@ -146,7 +146,7 @@ async def create_test_user() -> User:
         "email": "testuser@example.com",
         "name": "Test User",
     }
-    user = await get_or_create_user(test_user_data)
+    user, _ = await get_or_create_user(test_user_data)
     return user

@@ -21,7 +21,7 @@ async def create_test_user(alt_user: bool = False) -> User:
         "email": "testuser@example.com",
         "name": "Test User",
     }
-    user = await get_or_create_user(test_user_data)
+    user, _ = await get_or_create_user(test_user_data)
     return user

@@ -151,7 +151,7 @@ class TestDataCreator:
             }

             # Use the API function to create user in local database
-            user = await get_or_create_user(user_data)
+            user, _ = await get_or_create_user(user_data)
             users.append(user.model_dump())

         except Exception as e: