mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
2 Commits
dev
...
swiftyos/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d6338c1fc | ||
|
|
ea69536165 |
@@ -9,6 +9,21 @@ import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_user(test_user_id: str) -> str:
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
"""Pre-configured snapshot fixture with standard settings."""
|
||||
|
||||
@@ -29,6 +29,7 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
from backend.api.model import (
|
||||
BusinessUnderstandingPromptsResponse,
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
@@ -54,6 +55,7 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
@@ -158,6 +160,22 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/understanding/prompts",
|
||||
summary="Get business understanding prompts",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=BusinessUnderstandingPromptsResponse,
|
||||
)
|
||||
async def get_business_understanding_prompts_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> BusinessUnderstandingPromptsResponse:
|
||||
understanding = await understanding_db().get_business_understanding(user_id)
|
||||
return BusinessUnderstandingPromptsResponse(
|
||||
prompts=understanding.prompts if understanding else []
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/email",
|
||||
summary="Update user email",
|
||||
|
||||
@@ -89,6 +89,48 @@ def test_update_user_email_route(
|
||||
)
|
||||
|
||||
|
||||
def test_get_business_understanding_prompts_route(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_understanding_db = Mock()
|
||||
mock_understanding_db.get_business_understanding = AsyncMock(
|
||||
return_value=Mock(prompts=["Prompt one", "Prompt two", "Prompt three"])
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.understanding_db",
|
||||
return_value=mock_understanding_db,
|
||||
)
|
||||
|
||||
response = client.get("/auth/user/understanding/prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Prompt one", "Prompt two", "Prompt three"]}
|
||||
mock_understanding_db.get_business_understanding.assert_awaited_once_with(
|
||||
test_user_id
|
||||
)
|
||||
|
||||
|
||||
def test_get_business_understanding_prompts_route_returns_empty_list(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_understanding_db = Mock()
|
||||
mock_understanding_db.get_business_understanding = AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.understanding_db",
|
||||
return_value=mock_understanding_db,
|
||||
)
|
||||
|
||||
response = client.get("/auth/user/understanding/prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
mock_understanding_db.get_business_understanding.assert_awaited_once_with(
|
||||
test_user_id
|
||||
)
|
||||
|
||||
|
||||
# Blocks endpoints tests
|
||||
def test_get_graph_blocks(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
|
||||
@@ -85,6 +85,10 @@ class UpdateTimezoneRequest(pydantic.BaseModel):
|
||||
timezone: TimeZoneName
|
||||
|
||||
|
||||
class BusinessUnderstandingPromptsResponse(pydantic.BaseModel):
|
||||
prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class NotificationPayload(pydantic.BaseModel):
|
||||
type: str
|
||||
event: str
|
||||
|
||||
@@ -91,6 +91,7 @@ from backend.data.notifications import (
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
update_business_understanding_prompts,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
@@ -311,6 +312,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
update_business_understanding_prompts = _(update_business_understanding_prompts)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
@@ -379,6 +381,11 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
|
||||
|
||||
# Understanding
|
||||
get_business_understanding = _(d.get_business_understanding)
|
||||
update_business_understanding_prompts = _(d.update_business_understanding_prompts)
|
||||
upsert_business_understanding = _(d.upsert_business_understanding)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -493,6 +500,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
update_business_understanding_prompts = d.update_business_understanding_prompts
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
|
||||
@@ -10,10 +10,13 @@ from openai import AsyncOpenAI
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
BusinessUnderstandingInput,
|
||||
get_business_understanding,
|
||||
update_business_understanding_prompts,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.understanding_prompts import generate_understanding_prompts
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -418,9 +421,31 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
|
||||
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
|
||||
# Upsert into database
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
understanding = await upsert_business_understanding(
|
||||
user_id, understanding_input
|
||||
)
|
||||
await _generate_and_store_prompts(user_id, understanding)
|
||||
logger.info(f"Tally: successfully populated understanding for user {user_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Tally: error populating understanding for user {user_id}")
|
||||
|
||||
|
||||
async def _generate_and_store_prompts(
|
||||
user_id: str, understanding: BusinessUnderstanding
|
||||
) -> None:
|
||||
try:
|
||||
prompts = await generate_understanding_prompts(understanding)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Tally: skipping quick prompt generation for {user_id}: {e}")
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Tally: failed to generate quick prompts for understanding {user_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await update_business_understanding_prompts(user_id, prompts)
|
||||
except Exception:
|
||||
logger.exception(f"Tally: failed to store quick prompts for user {user_id}")
|
||||
|
||||
@@ -284,6 +284,7 @@ async def test_populate_understanding_full_flow():
|
||||
],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_understanding = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -305,12 +306,35 @@ async def test_populate_understanding_full_flow():
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_understanding,
|
||||
) as mock_upsert,
|
||||
patch(
|
||||
"backend.data.tally.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
"Help me automate customer support",
|
||||
"Find repetitive support work",
|
||||
"Show me faster support workflows",
|
||||
],
|
||||
) as mock_generate_prompts,
|
||||
patch(
|
||||
"backend.data.tally.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_prompts,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once_with("user-1", mock_input)
|
||||
mock_generate_prompts.assert_awaited_once_with(mock_understanding)
|
||||
mock_update_prompts.assert_awaited_once_with(
|
||||
"user-1",
|
||||
[
|
||||
"Help me automate customer support",
|
||||
"Find repetitive support work",
|
||||
"Show me faster support workflows",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -352,6 +376,55 @@ async def test_populate_understanding_handles_llm_timeout():
|
||||
mock_upsert.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_understanding_keeps_understanding_when_prompt_generation_fails():
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.tally_api_key = "test-key"
|
||||
|
||||
submission = {
|
||||
"responses": [{"questionId": "q1", "value": "Alice"}],
|
||||
}
|
||||
mock_input = MagicMock()
|
||||
mock_understanding = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.tally.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.data.tally.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.data.tally.find_submission_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(submission, SAMPLE_QUESTIONS),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_input,
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_understanding,
|
||||
) as mock_upsert,
|
||||
patch(
|
||||
"backend.data.tally.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("bad prompts"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.tally.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_prompts,
|
||||
):
|
||||
await populate_understanding_from_tally("user-1", "alice@example.com")
|
||||
|
||||
mock_upsert.assert_awaited_once_with("user-1", mock_input)
|
||||
mock_update_prompts.assert_not_awaited()
|
||||
|
||||
|
||||
# ── _mask_email ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
# Current tools
|
||||
current_software: list[str] = pydantic.Field(default_factory=list)
|
||||
existing_automation: list[str] = pydantic.Field(default_factory=list)
|
||||
prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
# Additional context
|
||||
additional_notes: Optional[str] = None
|
||||
@@ -148,6 +149,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
|
||||
automation_goals=_json_to_list(business.get("automation_goals")),
|
||||
current_software=_json_to_list(business.get("current_software")),
|
||||
existing_automation=_json_to_list(business.get("existing_automation")),
|
||||
prompts=_json_to_list(business.get("prompts")),
|
||||
additional_notes=business.get("additional_notes"),
|
||||
)
|
||||
|
||||
@@ -313,6 +315,40 @@ async def upsert_business_understanding(
|
||||
return understanding
|
||||
|
||||
|
||||
async def update_business_understanding_prompts(
|
||||
user_id: str, prompts: list[str]
|
||||
) -> Optional[BusinessUnderstanding]:
|
||||
"""Update derived quick prompts for an existing business understanding."""
|
||||
existing = await CoPilotUnderstanding.prisma().find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing is None:
|
||||
return None
|
||||
|
||||
existing_data: dict[str, Any] = {}
|
||||
if isinstance(existing.data, dict):
|
||||
existing_data = dict(existing.data)
|
||||
|
||||
existing_business: dict[str, Any] = {}
|
||||
if isinstance(existing_data.get("business"), dict):
|
||||
existing_business = dict(existing_data["business"])
|
||||
|
||||
existing_business["prompts"] = prompts
|
||||
existing_business["version"] = 1
|
||||
existing_data["business"] = existing_business
|
||||
|
||||
record = await CoPilotUnderstanding.prisma().update(
|
||||
where={"userId": user_id},
|
||||
data={"data": SafeJson(existing_data)},
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
await _set_cache(user_id, understanding)
|
||||
return understanding
|
||||
|
||||
|
||||
async def clear_business_understanding(user_id: str) -> bool:
|
||||
"""Clear/delete business understanding for a user from both DB and cache."""
|
||||
# Delete from cache first
|
||||
|
||||
115
autogpt_platform/backend/backend/data/understanding_prompts.py
Normal file
115
autogpt_platform/backend/backend/data/understanding_prompts.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Helpers for generating quick prompts from saved business understanding."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
format_understanding_for_prompt,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LLM_TIMEOUT = 30
|
||||
|
||||
_PROMPTS_PROMPT = """\
|
||||
You generate three short starter prompts for a user to click in a chat UI.
|
||||
|
||||
Return a JSON object with this exact shape:
|
||||
{"prompts":["...","...","..."]}
|
||||
|
||||
Requirements:
|
||||
- Exactly 3 prompts
|
||||
- Each prompt must be written in first person, as if the user is speaking
|
||||
- Each prompt must be shorter than 20 words
|
||||
- Keep them specific to the user's business context
|
||||
- Do not number the prompts
|
||||
- Do not add labels or explanations
|
||||
|
||||
Business context:
|
||||
"""
|
||||
|
||||
_PROMPTS_SUFFIX = "\n\nReturn ONLY valid JSON."
|
||||
|
||||
|
||||
def has_prompt_generation_context(understanding: BusinessUnderstanding) -> bool:
|
||||
return bool(format_understanding_for_prompt(understanding).strip())
|
||||
|
||||
|
||||
def _normalize_prompt(prompt: str) -> str:
|
||||
return " ".join(prompt.split())
|
||||
|
||||
|
||||
def _validate_prompts(value: object) -> list[str]:
|
||||
if not isinstance(value, list) or len(value) != 3:
|
||||
raise ValueError("Prompt response must contain exactly three prompts")
|
||||
|
||||
prompts: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for item in value:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("Each prompt must be a string")
|
||||
|
||||
prompt = _normalize_prompt(item)
|
||||
if not prompt:
|
||||
raise ValueError("Prompts cannot be empty")
|
||||
if len(prompt.split()) >= 20:
|
||||
raise ValueError("Prompts must be fewer than 20 words")
|
||||
if prompt in seen:
|
||||
raise ValueError("Prompts must be unique")
|
||||
|
||||
seen.add(prompt)
|
||||
prompts.append(prompt)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
async def generate_understanding_prompts(
|
||||
understanding: BusinessUnderstanding,
|
||||
) -> list[str]:
|
||||
"""Generate validated quick prompts from a saved understanding snapshot."""
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
if not context.strip():
|
||||
raise ValueError("Understanding does not contain usable context")
|
||||
|
||||
settings = Settings()
|
||||
client = AsyncOpenAI(
|
||||
api_key=settings.secrets.open_router_api_key,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model="openai/gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{_PROMPTS_PROMPT}{context}{_PROMPTS_SUFFIX}",
|
||||
}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.2,
|
||||
),
|
||||
timeout=_LLM_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Understanding prompts: generation timed out")
|
||||
raise
|
||||
|
||||
raw = response.choices[0].message.content or "{}"
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Understanding prompts: invalid JSON response")
|
||||
raise
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Prompt response must be a JSON object")
|
||||
|
||||
return _validate_prompts(data.get("prompts"))
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Tests for backend.data.understanding_prompts."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.understanding import BusinessUnderstanding
|
||||
from backend.data.understanding_prompts import generate_understanding_prompts
|
||||
|
||||
|
||||
def make_understanding(**overrides) -> BusinessUnderstanding:
|
||||
data = {
|
||||
"id": "understanding-1",
|
||||
"user_id": "user-1",
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"industry": "Customer support",
|
||||
"pain_points": ["manual ticket triage"],
|
||||
"automation_goals": ["speed up support responses"],
|
||||
}
|
||||
data.update(overrides)
|
||||
return BusinessUnderstanding(**data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_success():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"prompts": [
|
||||
"Help me automate customer support triage",
|
||||
"Show me how to speed up support replies",
|
||||
"Find repetitive work in our support process",
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
):
|
||||
prompts = await generate_understanding_prompts(make_understanding())
|
||||
|
||||
assert prompts == [
|
||||
"Help me automate customer support triage",
|
||||
"Show me how to speed up support replies",
|
||||
"Find repetitive work in our support process",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_rejects_duplicates():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"prompts": [
|
||||
"Help me automate customer support",
|
||||
"Help me automate customer support",
|
||||
"Find repetitive support work",
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
pytest.raises(ValueError, match="unique"),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_rejects_long_prompt():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{
|
||||
"prompts": [
|
||||
"Please help me automate every part of our customer support workflow starting with ticket triage routing follow-up escalation and reporting today",
|
||||
"Show me better support workflows",
|
||||
"Find support busywork for me",
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
pytest.raises(ValueError, match="fewer than 20 words"),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_rejects_invalid_shape():
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = json.dumps(
|
||||
{"prompts": ["Help me automate support", "Find repetitive work"]}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
pytest.raises(ValueError, match="exactly three prompts"),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_understanding_prompts_timeout():
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
|
||||
),
|
||||
patch("backend.data.understanding_prompts._LLM_TIMEOUT", 0.001),
|
||||
pytest.raises(asyncio.TimeoutError),
|
||||
):
|
||||
await generate_understanding_prompts(make_understanding())
|
||||
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Backfill quick prompts for saved business understanding records."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import click
|
||||
from prisma.models import CoPilotUnderstanding
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstanding,
|
||||
update_business_understanding_prompts,
|
||||
)
|
||||
from backend.data.understanding_prompts import (
|
||||
generate_understanding_prompts,
|
||||
has_prompt_generation_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def backfill_understanding_prompts(
|
||||
batch_size: int = 100,
|
||||
limit: int | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
summary = {
|
||||
"scanned": 0,
|
||||
"candidates": 0,
|
||||
"eligible": 0,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
records = await CoPilotUnderstanding.prisma().find_many(
|
||||
order={"id": "asc"},
|
||||
skip=offset,
|
||||
take=batch_size,
|
||||
)
|
||||
if not records:
|
||||
break
|
||||
|
||||
offset += len(records)
|
||||
|
||||
for record in records:
|
||||
summary["scanned"] += 1
|
||||
understanding = BusinessUnderstanding.from_db(record)
|
||||
|
||||
if understanding.prompts:
|
||||
summary["skipped_existing"] += 1
|
||||
continue
|
||||
|
||||
if limit is not None and summary["candidates"] >= limit:
|
||||
logger.info("Reached backfill limit of %s records", limit)
|
||||
return summary
|
||||
|
||||
summary["candidates"] += 1
|
||||
|
||||
if not has_prompt_generation_context(understanding):
|
||||
summary["skipped_no_context"] += 1
|
||||
continue
|
||||
|
||||
summary["eligible"] += 1
|
||||
if dry_run:
|
||||
continue
|
||||
|
||||
try:
|
||||
prompts = await generate_understanding_prompts(understanding)
|
||||
updated = await update_business_understanding_prompts(
|
||||
understanding.user_id, prompts
|
||||
)
|
||||
except Exception:
|
||||
summary["failed"] += 1
|
||||
logger.exception(
|
||||
"Failed to backfill prompts for user %s", understanding.user_id
|
||||
)
|
||||
continue
|
||||
|
||||
if updated is None:
|
||||
summary["failed"] += 1
|
||||
logger.warning(
|
||||
"Skipped backfill for user %s because the record no longer exists",
|
||||
understanding.user_id,
|
||||
)
|
||||
continue
|
||||
|
||||
summary["updated"] += 1
|
||||
|
||||
logger.info("Understanding prompt backfill summary: %s", json.dumps(summary))
|
||||
return summary
|
||||
|
||||
|
||||
async def run_backfill(
|
||||
batch_size: int = 100,
|
||||
limit: int | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
await db.connect()
|
||||
try:
|
||||
return await backfill_understanding_prompts(
|
||||
batch_size=batch_size,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
finally:
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Report candidates only.")
|
||||
@click.option("--limit", type=click.IntRange(min=1), default=None)
|
||||
@click.option(
|
||||
"--batch-size", type=click.IntRange(min=1), default=100, show_default=True
|
||||
)
|
||||
def main(dry_run: bool, limit: int | None, batch_size: int) -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
summary = asyncio.run(
|
||||
run_backfill(
|
||||
batch_size=batch_size,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
)
|
||||
click.echo(json.dumps(summary, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Tests for the understanding prompt backfill script."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.backfill_understanding_prompts import backfill_understanding_prompts
|
||||
|
||||
|
||||
def make_record(*, user_id: str, business: dict) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=f"understanding-{user_id}",
|
||||
userId=user_id,
|
||||
createdAt=datetime.now(timezone.utc),
|
||||
updatedAt=datetime.now(timezone.utc),
|
||||
data={"business": business},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_dry_run():
|
||||
record = make_record(
|
||||
user_id="user-1",
|
||||
business={"business_name": "Acme", "industry": "Support"},
|
||||
)
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
),
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10, dry_run=True)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 1,
|
||||
"eligible": 1,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
mock_generate.assert_not_awaited()
|
||||
mock_update.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_skips_existing_prompts():
|
||||
record = make_record(
|
||||
user_id="user-1",
|
||||
business={
|
||||
"business_name": "Acme",
|
||||
"prompts": ["Prompt one", "Prompt two", "Prompt three"],
|
||||
},
|
||||
)
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 0,
|
||||
"eligible": 0,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 1,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_updates_missing_prompts():
|
||||
record = make_record(
|
||||
user_id="user-1",
|
||||
business={"business_name": "Acme", "industry": "Support"},
|
||||
)
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
),
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
return_value=["Prompt one", "Prompt two", "Prompt three"],
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.update_business_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
return_value=object(),
|
||||
) as mock_update,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 1,
|
||||
"eligible": 1,
|
||||
"updated": 1,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 0,
|
||||
}
|
||||
mock_generate.assert_awaited_once()
|
||||
mock_update.assert_awaited_once_with(
|
||||
"user-1", ["Prompt one", "Prompt two", "Prompt three"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_understanding_prompts_skips_records_without_context():
|
||||
record = make_record(user_id="user-1", business={})
|
||||
prisma = AsyncMock()
|
||||
prisma.find_many.side_effect = [[record], []]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
|
||||
return_value=prisma,
|
||||
),
|
||||
patch(
|
||||
"scripts.backfill_understanding_prompts.generate_understanding_prompts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
):
|
||||
summary = await backfill_understanding_prompts(batch_size=10)
|
||||
|
||||
assert summary == {
|
||||
"scanned": 1,
|
||||
"candidates": 1,
|
||||
"eligible": 0,
|
||||
"updated": 0,
|
||||
"failed": 0,
|
||||
"skipped_existing": 0,
|
||||
"skipped_no_context": 1,
|
||||
}
|
||||
mock_generate.assert_not_awaited()
|
||||
@@ -7,11 +7,8 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { SpinnerGapIcon } from "@phosphor-icons/react";
|
||||
import { motion } from "framer-motion";
|
||||
import { useEffect, useState } from "react";
|
||||
import {
|
||||
getGreetingName,
|
||||
getInputPlaceholder,
|
||||
getQuickActions,
|
||||
} from "./helpers";
|
||||
import { getGreetingName, getInputPlaceholder } from "./helpers";
|
||||
import { useQuickActions } from "./useQuickActions";
|
||||
|
||||
interface Props {
|
||||
inputLayoutId: string;
|
||||
@@ -33,7 +30,7 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
const quickActions = getQuickActions();
|
||||
const quickActions = useQuickActions(user);
|
||||
const [loadingAction, setLoadingAction] = useState<string | null>(null);
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { User } from "@supabase/supabase-js";
|
||||
|
||||
export const DEFAULT_QUICK_ACTIONS = [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
|
||||
export function getInputPlaceholder(width?: number) {
|
||||
if (!width) return "What's your role and what eats up most of your day?";
|
||||
|
||||
@@ -12,12 +18,15 @@ export function getInputPlaceholder(width?: number) {
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
export function getQuickActions() {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
];
|
||||
export function getQuickActions(prompts?: string[] | null) {
|
||||
const normalizedPrompts =
|
||||
prompts
|
||||
?.map((prompt) => prompt.trim())
|
||||
.filter((prompt) => prompt.length > 0) ?? [];
|
||||
|
||||
return normalizedPrompts.length > 0
|
||||
? normalizedPrompts
|
||||
: DEFAULT_QUICK_ACTIONS;
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null) {
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
import { DEFAULT_QUICK_ACTIONS } from "./helpers";
|
||||
import { useQuickActions } from "./useQuickActions";
|
||||
|
||||
const { mockUseGetV1GetBusinessUnderstandingPrompts } = vi.hoisted(() => ({
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/auth/auth", () => ({
|
||||
useGetV1GetBusinessUnderstandingPrompts:
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts,
|
||||
}));
|
||||
|
||||
function makeUser() {
|
||||
return { id: "user-1" } as User;
|
||||
}
|
||||
|
||||
describe("useQuickActions", () => {
|
||||
it("uses server prompts when available", () => {
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
|
||||
data: ["Help me automate onboarding", "Find my biggest bottleneck"],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useQuickActions(makeUser()));
|
||||
|
||||
expect(result.current).toEqual([
|
||||
"Help me automate onboarding",
|
||||
"Find my biggest bottleneck",
|
||||
]);
|
||||
expect(mockUseGetV1GetBusinessUnderstandingPrompts).toHaveBeenCalledWith({
|
||||
query: expect.objectContaining({ enabled: true }),
|
||||
});
|
||||
});
|
||||
|
||||
it("falls back to defaults when the user is not authenticated", () => {
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
|
||||
data: undefined,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useQuickActions(null));
|
||||
|
||||
expect(result.current).toEqual(DEFAULT_QUICK_ACTIONS);
|
||||
expect(mockUseGetV1GetBusinessUnderstandingPrompts).toHaveBeenCalledWith({
|
||||
query: expect.objectContaining({ enabled: false }),
|
||||
});
|
||||
});
|
||||
|
||||
it("falls back to defaults when the API returns no prompts", () => {
|
||||
mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
|
||||
data: [],
|
||||
error: new Error("no prompts"),
|
||||
isError: true,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useQuickActions(makeUser()));
|
||||
|
||||
expect(result.current).toEqual(DEFAULT_QUICK_ACTIONS);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV1GetBusinessUnderstandingPrompts } from "@/app/api/__generated__/endpoints/auth/auth";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { User } from "@supabase/supabase-js";
|
||||
import { getQuickActions } from "./helpers";
|
||||
|
||||
export function useQuickActions(user?: User | null) {
|
||||
const quickPrompts = useGetV1GetBusinessUnderstandingPrompts({
|
||||
query: {
|
||||
enabled: Boolean(user),
|
||||
select: (response) => okData(response)?.prompts,
|
||||
},
|
||||
}).data;
|
||||
|
||||
return getQuickActions(quickPrompts);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user