Compare commits

...

2 Commits

Author SHA1 Message Date
Swifty
7d6338c1fc Merge branch 'dev' into swiftyos/custom-user-prompts 2026-03-09 15:00:15 +01:00
Swifty
ea69536165 add custom user prompts 2026-03-09 14:59:13 +01:00
17 changed files with 871 additions and 14567 deletions

View File

@@ -9,6 +9,21 @@ import pytest
from pytest_snapshot.plugin import Snapshot
@pytest.fixture
def test_user_id() -> str:
    """Canonical user id used by authenticated-route tests."""
    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
    """Canonical admin user id for tests that need elevated access."""
    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def setup_test_user(test_user_id: str) -> str:
    """Alias fixture that exposes the standard test user's id."""
    return test_user_id
@pytest.fixture
def configured_snapshot(snapshot: Snapshot) -> Snapshot:
"""Pre-configured snapshot fixture with standard settings."""

View File

@@ -29,6 +29,7 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
from backend.api.model import (
BusinessUnderstandingPromptsResponse,
CreateAPIKeyRequest,
CreateAPIKeyResponse,
CreateGraph,
@@ -54,6 +55,7 @@ from backend.data.credit import (
get_user_credit_model,
set_auto_top_up,
)
from backend.data.db_accessors import understanding_db
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
@@ -158,6 +160,22 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
return user.model_dump()
@v1_router.get(
    "/auth/user/understanding/prompts",
    summary="Get business understanding prompts",
    tags=["auth"],
    dependencies=[Security(requires_user)],
    response_model=BusinessUnderstandingPromptsResponse,
)
async def get_business_understanding_prompts_route(
    user_id: Annotated[str, Security(get_user_id)],
) -> BusinessUnderstandingPromptsResponse:
    """Return the caller's derived quick prompts, or an empty list when no
    business understanding has been stored for them yet."""
    # understanding_db() resolves the data-access layer at call time, which
    # also keeps this route easy to patch in tests.
    understanding = await understanding_db().get_business_understanding(user_id)
    return BusinessUnderstandingPromptsResponse(
        prompts=understanding.prompts if understanding else []
    )
@v1_router.post(
"/auth/user/email",
summary="Update user email",

View File

@@ -89,6 +89,48 @@ def test_update_user_email_route(
)
def test_get_business_understanding_prompts_route(
    mocker: pytest_mock.MockFixture,
    test_user_id: str,
) -> None:
    """Prompts endpoint returns the stored prompts for the authenticated user."""
    # Patch the accessor factory so the route receives this mock DB layer.
    mock_understanding_db = Mock()
    mock_understanding_db.get_business_understanding = AsyncMock(
        return_value=Mock(prompts=["Prompt one", "Prompt two", "Prompt three"])
    )
    mocker.patch(
        "backend.api.features.v1.understanding_db",
        return_value=mock_understanding_db,
    )
    response = client.get("/auth/user/understanding/prompts")
    assert response.status_code == 200
    assert response.json() == {"prompts": ["Prompt one", "Prompt two", "Prompt three"]}
    # The route must look up the understanding for the authenticated user id
    # (presumably injected by the test client's auth override — see conftest).
    mock_understanding_db.get_business_understanding.assert_awaited_once_with(
        test_user_id
    )
def test_get_business_understanding_prompts_route_returns_empty_list(
    mocker: pytest_mock.MockFixture,
    test_user_id: str,
) -> None:
    """The endpoint degrades to an empty prompt list when no understanding exists."""
    mock_understanding_db = Mock()
    # None models "no understanding row stored for this user".
    mock_understanding_db.get_business_understanding = AsyncMock(return_value=None)
    mocker.patch(
        "backend.api.features.v1.understanding_db",
        return_value=mock_understanding_db,
    )
    response = client.get("/auth/user/understanding/prompts")
    assert response.status_code == 200
    assert response.json() == {"prompts": []}
    mock_understanding_db.get_business_understanding.assert_awaited_once_with(
        test_user_id
    )
# Blocks endpoints tests
def test_get_graph_blocks(
mocker: pytest_mock.MockFixture,

View File

@@ -85,6 +85,10 @@ class UpdateTimezoneRequest(pydantic.BaseModel):
timezone: TimeZoneName
class BusinessUnderstandingPromptsResponse(pydantic.BaseModel):
    """Response body for GET /auth/user/understanding/prompts."""

    # Derived quick prompts; empty when no understanding is stored yet.
    prompts: list[str] = pydantic.Field(default_factory=list)
class NotificationPayload(pydantic.BaseModel):
type: str
event: str

View File

@@ -91,6 +91,7 @@ from backend.data.notifications import (
from backend.data.onboarding import increment_onboarding_runs
from backend.data.understanding import (
get_business_understanding,
update_business_understanding_prompts,
upsert_business_understanding,
)
from backend.data.user import (
@@ -311,6 +312,7 @@ class DatabaseManager(AppService):
# ============ Understanding ============ #
get_business_understanding = _(get_business_understanding)
update_business_understanding_prompts = _(update_business_understanding_prompts)
upsert_business_understanding = _(upsert_business_understanding)
# ============ CoPilot Chat Sessions ============ #
@@ -379,6 +381,11 @@ class DatabaseManagerClient(AppServiceClient):
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
# Understanding
get_business_understanding = _(d.get_business_understanding)
update_business_understanding_prompts = _(d.update_business_understanding_prompts)
upsert_business_understanding = _(d.upsert_business_understanding)
class DatabaseManagerAsyncClient(AppServiceClient):
d = DatabaseManager
@@ -493,6 +500,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
update_business_understanding_prompts = d.update_business_understanding_prompts
upsert_business_understanding = d.upsert_business_understanding
# ============ CoPilot Chat Sessions ============ #

View File

@@ -10,10 +10,13 @@ from openai import AsyncOpenAI
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
BusinessUnderstanding,
BusinessUnderstandingInput,
get_business_understanding,
update_business_understanding_prompts,
upsert_business_understanding,
)
from backend.data.understanding_prompts import generate_understanding_prompts
from backend.util.request import Requests
from backend.util.settings import Settings
@@ -418,9 +421,31 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
understanding = await upsert_business_understanding(
user_id, understanding_input
)
await _generate_and_store_prompts(user_id, understanding)
logger.info(f"Tally: successfully populated understanding for user {user_id}")
except Exception:
logger.exception(f"Tally: error populating understanding for user {user_id}")
async def _generate_and_store_prompts(
    user_id: str, understanding: BusinessUnderstanding
) -> None:
    """Best-effort: derive quick prompts from *understanding* and persist them.

    Never raises — generation or storage failures are logged and swallowed so
    the surrounding Tally population flow is not interrupted.
    """
    try:
        generated = await generate_understanding_prompts(understanding)
    except ValueError as e:
        # Expected failure mode: no usable context or an invalid LLM payload.
        logger.warning(f"Tally: skipping quick prompt generation for {user_id}: {e}")
        return
    except Exception:
        logger.exception(
            f"Tally: failed to generate quick prompts for understanding {user_id}"
        )
        return

    try:
        await update_business_understanding_prompts(user_id, generated)
    except Exception:
        logger.exception(f"Tally: failed to store quick prompts for user {user_id}")

View File

@@ -284,6 +284,7 @@ async def test_populate_understanding_full_flow():
],
}
mock_input = MagicMock()
mock_understanding = MagicMock()
with (
patch(
@@ -305,12 +306,35 @@ async def test_populate_understanding_full_flow():
patch(
"backend.data.tally.upsert_business_understanding",
new_callable=AsyncMock,
return_value=mock_understanding,
) as mock_upsert,
patch(
"backend.data.tally.generate_understanding_prompts",
new_callable=AsyncMock,
return_value=[
"Help me automate customer support",
"Find repetitive support work",
"Show me faster support workflows",
],
) as mock_generate_prompts,
patch(
"backend.data.tally.update_business_understanding_prompts",
new_callable=AsyncMock,
) as mock_update_prompts,
):
await populate_understanding_from_tally("user-1", "alice@example.com")
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once_with("user-1", mock_input)
mock_generate_prompts.assert_awaited_once_with(mock_understanding)
mock_update_prompts.assert_awaited_once_with(
"user-1",
[
"Help me automate customer support",
"Find repetitive support work",
"Show me faster support workflows",
],
)
@pytest.mark.asyncio
@@ -352,6 +376,55 @@ async def test_populate_understanding_handles_llm_timeout():
mock_upsert.assert_not_awaited()
@pytest.mark.asyncio
async def test_populate_understanding_keeps_understanding_when_prompt_generation_fails():
    """Upsert still happens (and nothing is stored) when prompt generation raises."""
    mock_settings = MagicMock()
    mock_settings.secrets.tally_api_key = "test-key"
    submission = {
        "responses": [{"questionId": "q1", "value": "Alice"}],
    }
    mock_input = MagicMock()
    mock_understanding = MagicMock()
    with (
        patch(
            "backend.data.tally.get_business_understanding",
            new_callable=AsyncMock,
            return_value=None,
        ),
        patch("backend.data.tally.Settings", return_value=mock_settings),
        patch(
            "backend.data.tally.find_submission_by_email",
            new_callable=AsyncMock,
            return_value=(submission, SAMPLE_QUESTIONS),
        ),
        patch(
            "backend.data.tally.extract_business_understanding",
            new_callable=AsyncMock,
            return_value=mock_input,
        ) ,
        patch(
            "backend.data.tally.upsert_business_understanding",
            new_callable=AsyncMock,
            return_value=mock_understanding,
        ) as mock_upsert,
        # Prompt generation fails with the "expected" validation error.
        patch(
            "backend.data.tally.generate_understanding_prompts",
            new_callable=AsyncMock,
            side_effect=ValueError("bad prompts"),
        ),
        patch(
            "backend.data.tally.update_business_understanding_prompts",
            new_callable=AsyncMock,
        ) as mock_update_prompts,
    ):
        await populate_understanding_from_tally("user-1", "alice@example.com")
    # The understanding itself is persisted despite the prompt failure...
    mock_upsert.assert_awaited_once_with("user-1", mock_input)
    # ...and no prompts are written.
    mock_update_prompts.assert_not_awaited()
# ── _mask_email ───────────────────────────────────────────────────────────────

View File

@@ -118,6 +118,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
# Current tools
current_software: list[str] = pydantic.Field(default_factory=list)
existing_automation: list[str] = pydantic.Field(default_factory=list)
prompts: list[str] = pydantic.Field(default_factory=list)
# Additional context
additional_notes: Optional[str] = None
@@ -148,6 +149,7 @@ class BusinessUnderstanding(pydantic.BaseModel):
automation_goals=_json_to_list(business.get("automation_goals")),
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
prompts=_json_to_list(business.get("prompts")),
additional_notes=business.get("additional_notes"),
)
@@ -313,6 +315,40 @@ async def upsert_business_understanding(
return understanding
async def update_business_understanding_prompts(
    user_id: str, prompts: list[str]
) -> Optional[BusinessUnderstanding]:
    """Update derived quick prompts for an existing business understanding.

    Returns the refreshed BusinessUnderstanding, or None when the user has no
    stored understanding (or the record disappeared before the update landed).
    """
    existing = await CoPilotUnderstanding.prisma().find_unique(
        where={"userId": user_id}
    )
    if existing is None:
        return None
    # Shallow-copy the stored JSON so the Prisma model is never mutated in place.
    existing_data: dict[str, Any] = {}
    if isinstance(existing.data, dict):
        existing_data = dict(existing.data)
    existing_business: dict[str, Any] = {}
    if isinstance(existing_data.get("business"), dict):
        existing_business = dict(existing_data["business"])
    existing_business["prompts"] = prompts
    # NOTE(review): version is pinned to 1 unconditionally — confirm it should
    # not preserve a pre-existing version value instead.
    existing_business["version"] = 1
    existing_data["business"] = existing_business
    # NOTE(review): this read-modify-write is not atomic; concurrent updates to
    # the same user's understanding could clobber each other.
    record = await CoPilotUnderstanding.prisma().update(
        where={"userId": user_id},
        data={"data": SafeJson(existing_data)},
    )
    if record is None:
        return None
    understanding = BusinessUnderstanding.from_db(record)
    # Keep the cache in sync with what was just written.
    await _set_cache(user_id, understanding)
    return understanding
async def clear_business_understanding(user_id: str) -> bool:
"""Clear/delete business understanding for a user from both DB and cache."""
# Delete from cache first

View File

@@ -0,0 +1,115 @@
"""Helpers for generating quick prompts from saved business understanding."""
import asyncio
import json
import logging
from openai import AsyncOpenAI
from backend.data.understanding import (
BusinessUnderstanding,
format_understanding_for_prompt,
)
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_LLM_TIMEOUT = 30
_PROMPTS_PROMPT = """\
You generate three short starter prompts for a user to click in a chat UI.
Return a JSON object with this exact shape:
{"prompts":["...","...","..."]}
Requirements:
- Exactly 3 prompts
- Each prompt must be written in first person, as if the user is speaking
- Each prompt must be shorter than 20 words
- Keep them specific to the user's business context
- Do not number the prompts
- Do not add labels or explanations
Business context:
"""
_PROMPTS_SUFFIX = "\n\nReturn ONLY valid JSON."
def has_prompt_generation_context(understanding: BusinessUnderstanding) -> bool:
    """True when *understanding* renders to non-whitespace prompt context."""
    rendered = format_understanding_for_prompt(understanding)
    return bool(rendered.strip())
def _normalize_prompt(prompt: str) -> str:
return " ".join(prompt.split())
def _validate_prompts(value: object) -> list[str]:
if not isinstance(value, list) or len(value) != 3:
raise ValueError("Prompt response must contain exactly three prompts")
prompts: list[str] = []
seen: set[str] = set()
for item in value:
if not isinstance(item, str):
raise ValueError("Each prompt must be a string")
prompt = _normalize_prompt(item)
if not prompt:
raise ValueError("Prompts cannot be empty")
if len(prompt.split()) >= 20:
raise ValueError("Prompts must be fewer than 20 words")
if prompt in seen:
raise ValueError("Prompts must be unique")
seen.add(prompt)
prompts.append(prompt)
return prompts
async def generate_understanding_prompts(
    understanding: BusinessUnderstanding,
) -> list[str]:
    """Generate validated quick prompts from a saved understanding snapshot.

    Raises ValueError when the understanding has no usable context or the
    LLM reply fails validation; propagates asyncio.TimeoutError and
    json.JSONDecodeError from the call/parse steps.
    """
    context = format_understanding_for_prompt(understanding)
    if not context.strip():
        raise ValueError("Understanding does not contain usable context")
    settings = Settings()
    # OpenRouter exposes an OpenAI-compatible API, so the OpenAI SDK is reused.
    # NOTE(review): the client is never closed — consider `async with client:`.
    client = AsyncOpenAI(
        api_key=settings.secrets.open_router_api_key,
        base_url="https://openrouter.ai/api/v1",
    )
    try:
        # Hard cap the call with wait_for; _LLM_TIMEOUT is patched down in tests.
        response = await asyncio.wait_for(
            client.chat.completions.create(
                model="openai/gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": f"{_PROMPTS_PROMPT}{context}{_PROMPTS_SUFFIX}",
                    }
                ],
                # Request a JSON object; low temperature for stable output.
                response_format={"type": "json_object"},
                temperature=0.2,
            ),
            timeout=_LLM_TIMEOUT,
        )
    except asyncio.TimeoutError:
        logger.warning("Understanding prompts: generation timed out")
        raise
    # An empty/None completion falls through to "{}" and fails shape validation.
    raw = response.choices[0].message.content or "{}"
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("Understanding prompts: invalid JSON response")
        raise
    if not isinstance(data, dict):
        raise ValueError("Prompt response must be a JSON object")
    return _validate_prompts(data.get("prompts"))

View File

@@ -0,0 +1,145 @@
"""Tests for backend.data.understanding_prompts."""
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.understanding import BusinessUnderstanding
from backend.data.understanding_prompts import generate_understanding_prompts
def make_understanding(**overrides) -> BusinessUnderstanding:
    """Build a BusinessUnderstanding with sensible defaults, overridable per test."""
    defaults = {
        "id": "understanding-1",
        "user_id": "user-1",
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
        "industry": "Customer support",
        "pain_points": ["manual ticket triage"],
        "automation_goals": ["speed up support responses"],
    }
    return BusinessUnderstanding(**{**defaults, **overrides})
@pytest.mark.asyncio
async def test_generate_understanding_prompts_success():
    """A well-formed three-prompt JSON reply is returned unchanged."""
    mock_choice = MagicMock()
    mock_choice.message.content = json.dumps(
        {
            "prompts": [
                "Help me automate customer support triage",
                "Show me how to speed up support replies",
                "Find repetitive work in our support process",
            ]
        }
    )
    mock_response = MagicMock()
    mock_response.choices = [mock_choice]
    mock_client = AsyncMock()
    mock_client.chat.completions.create.return_value = mock_response
    # Swap the OpenAI client constructor for one returning the canned reply.
    with patch(
        "backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
    ):
        prompts = await generate_understanding_prompts(make_understanding())
    assert prompts == [
        "Help me automate customer support triage",
        "Show me how to speed up support replies",
        "Find repetitive work in our support process",
    ]
@pytest.mark.asyncio
async def test_generate_understanding_prompts_rejects_duplicates():
    """Duplicate prompts in the LLM reply fail validation."""
    mock_choice = MagicMock()
    mock_choice.message.content = json.dumps(
        {
            "prompts": [
                "Help me automate customer support",
                "Help me automate customer support",
                "Find repetitive support work",
            ]
        }
    )
    mock_response = MagicMock()
    mock_response.choices = [mock_choice]
    mock_client = AsyncMock()
    mock_client.chat.completions.create.return_value = mock_response
    with (
        patch(
            "backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
        ),
        pytest.raises(ValueError, match="unique"),
    ):
        await generate_understanding_prompts(make_understanding())
@pytest.mark.asyncio
async def test_generate_understanding_prompts_rejects_long_prompt():
    """A prompt of 20+ words fails the length validation."""
    mock_choice = MagicMock()
    # First prompt is deliberately over the 20-word cap.
    mock_choice.message.content = json.dumps(
        {
            "prompts": [
                "Please help me automate every part of our customer support workflow starting with ticket triage routing follow-up escalation and reporting today",
                "Show me better support workflows",
                "Find support busywork for me",
            ]
        }
    )
    mock_response = MagicMock()
    mock_response.choices = [mock_choice]
    mock_client = AsyncMock()
    mock_client.chat.completions.create.return_value = mock_response
    with (
        patch(
            "backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
        ),
        pytest.raises(ValueError, match="fewer than 20 words"),
    ):
        await generate_understanding_prompts(make_understanding())
@pytest.mark.asyncio
async def test_generate_understanding_prompts_rejects_invalid_shape():
    """Fewer than three prompts in the reply fails shape validation."""
    mock_choice = MagicMock()
    mock_choice.message.content = json.dumps(
        {"prompts": ["Help me automate support", "Find repetitive work"]}
    )
    mock_response = MagicMock()
    mock_response.choices = [mock_choice]
    mock_client = AsyncMock()
    mock_client.chat.completions.create.return_value = mock_response
    with (
        patch(
            "backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
        ),
        pytest.raises(ValueError, match="exactly three prompts"),
    ):
        await generate_understanding_prompts(make_understanding())
@pytest.mark.asyncio
async def test_generate_understanding_prompts_timeout():
    """asyncio.TimeoutError from the client call propagates to the caller."""
    mock_client = AsyncMock()
    mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
    with (
        patch(
            "backend.data.understanding_prompts.AsyncOpenAI", return_value=mock_client
        ),
        # Shrink the module timeout so wait_for cannot mask the mocked error.
        patch("backend.data.understanding_prompts._LLM_TIMEOUT", 0.001),
        pytest.raises(asyncio.TimeoutError),
    ):
        await generate_understanding_prompts(make_understanding())

View File

@@ -0,0 +1,134 @@
#!/usr/bin/env python3
"""Backfill quick prompts for saved business understanding records."""
import asyncio
import json
import logging
import click
from prisma.models import CoPilotUnderstanding
from backend.data import db
from backend.data.understanding import (
BusinessUnderstanding,
update_business_understanding_prompts,
)
from backend.data.understanding_prompts import (
generate_understanding_prompts,
has_prompt_generation_context,
)
logger = logging.getLogger(__name__)
async def backfill_understanding_prompts(
    batch_size: int = 100,
    limit: int | None = None,
    dry_run: bool = False,
) -> dict[str, int]:
    """Generate and store quick prompts for understandings that lack them.

    Scans CoPilotUnderstanding rows in id order, skipping records that already
    carry prompts or render no usable context. With ``dry_run`` eligible
    records are counted but nothing is generated or written. ``limit`` caps the
    number of candidate (prompt-less) records processed.

    Returns a summary dict with counters: scanned, candidates, eligible,
    updated, failed, skipped_existing, skipped_no_context.
    """
    summary = {
        "scanned": 0,
        "candidates": 0,
        "eligible": 0,
        "updated": 0,
        "failed": 0,
        "skipped_existing": 0,
        "skipped_no_context": 0,
    }
    offset = 0
    while True:
        records = await CoPilotUnderstanding.prisma().find_many(
            order={"id": "asc"},
            skip=offset,
            take=batch_size,
        )
        if not records:
            break
        offset += len(records)
        for record in records:
            summary["scanned"] += 1
            understanding = BusinessUnderstanding.from_db(record)
            if understanding.prompts:
                summary["skipped_existing"] += 1
                continue
            if limit is not None and summary["candidates"] >= limit:
                logger.info("Reached backfill limit of %s records", limit)
                # BUGFIX: the early-exit path previously returned without
                # logging the summary, so limited runs were unobservable.
                logger.info(
                    "Understanding prompt backfill summary: %s", json.dumps(summary)
                )
                return summary
            summary["candidates"] += 1
            if not has_prompt_generation_context(understanding):
                summary["skipped_no_context"] += 1
                continue
            summary["eligible"] += 1
            if dry_run:
                continue
            try:
                prompts = await generate_understanding_prompts(understanding)
                updated = await update_business_understanding_prompts(
                    understanding.user_id, prompts
                )
            except Exception:
                # Best-effort batch job: record the failure and keep going.
                summary["failed"] += 1
                logger.exception(
                    "Failed to backfill prompts for user %s", understanding.user_id
                )
                continue
            if updated is None:
                # The row vanished between the scan and the update.
                summary["failed"] += 1
                logger.warning(
                    "Skipped backfill for user %s because the record no longer exists",
                    understanding.user_id,
                )
                continue
            summary["updated"] += 1
    logger.info("Understanding prompt backfill summary: %s", json.dumps(summary))
    return summary
async def run_backfill(
    batch_size: int = 100,
    limit: int | None = None,
    dry_run: bool = False,
) -> dict[str, int]:
    """Connect to the database, run the backfill, and always disconnect."""
    await db.connect()
    try:
        summary = await backfill_understanding_prompts(
            batch_size=batch_size,
            limit=limit,
            dry_run=dry_run,
        )
    finally:
        # Disconnect even when the backfill raises.
        await db.disconnect()
    return summary
@click.command()
@click.option("--dry-run", is_flag=True, default=False, help="Report candidates only.")
@click.option("--limit", type=click.IntRange(min=1), default=None)
@click.option(
    "--batch-size", type=click.IntRange(min=1), default=100, show_default=True
)
def main(dry_run: bool, limit: int | None, batch_size: int) -> None:
    """CLI entry point: run the prompt backfill and print the summary as JSON."""
    logging.basicConfig(level=logging.INFO)
    summary = asyncio.run(
        run_backfill(
            batch_size=batch_size,
            limit=limit,
            dry_run=dry_run,
        )
    )
    click.echo(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,158 @@
"""Tests for the understanding prompt backfill script."""
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from scripts.backfill_understanding_prompts import backfill_understanding_prompts
def make_record(*, user_id: str, business: dict) -> SimpleNamespace:
    """Build a SimpleNamespace mimicking a CoPilotUnderstanding DB row."""
    fields = {
        "id": f"understanding-{user_id}",
        "userId": user_id,
        "createdAt": datetime.now(timezone.utc),
        "updatedAt": datetime.now(timezone.utc),
        "data": {"business": business},
    }
    return SimpleNamespace(**fields)
@pytest.mark.asyncio
async def test_backfill_understanding_prompts_dry_run():
    """Dry run counts eligible records without generating or writing prompts."""
    record = make_record(
        user_id="user-1",
        business={"business_name": "Acme", "industry": "Support"},
    )
    prisma = AsyncMock()
    # First page returns one record; the empty second page ends the scan loop.
    prisma.find_many.side_effect = [[record], []]
    with (
        patch(
            "scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
            return_value=prisma,
        ),
        patch(
            "scripts.backfill_understanding_prompts.generate_understanding_prompts",
            new_callable=AsyncMock,
        ) as mock_generate,
        patch(
            "scripts.backfill_understanding_prompts.update_business_understanding_prompts",
            new_callable=AsyncMock,
        ) as mock_update,
    ):
        summary = await backfill_understanding_prompts(batch_size=10, dry_run=True)
    assert summary == {
        "scanned": 1,
        "candidates": 1,
        "eligible": 1,
        "updated": 0,
        "failed": 0,
        "skipped_existing": 0,
        "skipped_no_context": 0,
    }
    mock_generate.assert_not_awaited()
    mock_update.assert_not_awaited()
@pytest.mark.asyncio
async def test_backfill_understanding_prompts_skips_existing_prompts():
    """Records that already carry prompts are skipped entirely."""
    record = make_record(
        user_id="user-1",
        business={
            "business_name": "Acme",
            "prompts": ["Prompt one", "Prompt two", "Prompt three"],
        },
    )
    prisma = AsyncMock()
    prisma.find_many.side_effect = [[record], []]
    with patch(
        "scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
        return_value=prisma,
    ):
        summary = await backfill_understanding_prompts(batch_size=10)
    assert summary == {
        "scanned": 1,
        "candidates": 0,
        "eligible": 0,
        "updated": 0,
        "failed": 0,
        "skipped_existing": 1,
        "skipped_no_context": 0,
    }
@pytest.mark.asyncio
async def test_backfill_understanding_prompts_updates_missing_prompts():
    """Eligible records get prompts generated and persisted."""
    record = make_record(
        user_id="user-1",
        business={"business_name": "Acme", "industry": "Support"},
    )
    prisma = AsyncMock()
    prisma.find_many.side_effect = [[record], []]
    with (
        patch(
            "scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
            return_value=prisma,
        ),
        patch(
            "scripts.backfill_understanding_prompts.generate_understanding_prompts",
            new_callable=AsyncMock,
            return_value=["Prompt one", "Prompt two", "Prompt three"],
        ) as mock_generate,
        # Non-None return signals the DB row was updated successfully.
        patch(
            "scripts.backfill_understanding_prompts.update_business_understanding_prompts",
            new_callable=AsyncMock,
            return_value=object(),
        ) as mock_update,
    ):
        summary = await backfill_understanding_prompts(batch_size=10)
    assert summary == {
        "scanned": 1,
        "candidates": 1,
        "eligible": 1,
        "updated": 1,
        "failed": 0,
        "skipped_existing": 0,
        "skipped_no_context": 0,
    }
    mock_generate.assert_awaited_once()
    mock_update.assert_awaited_once_with(
        "user-1", ["Prompt one", "Prompt two", "Prompt three"]
    )
@pytest.mark.asyncio
async def test_backfill_understanding_prompts_skips_records_without_context():
    """Records whose understanding renders no context are counted but not generated."""
    # Empty business dict -> no usable prompt-generation context.
    record = make_record(user_id="user-1", business={})
    prisma = AsyncMock()
    prisma.find_many.side_effect = [[record], []]
    with (
        patch(
            "scripts.backfill_understanding_prompts.CoPilotUnderstanding.prisma",
            return_value=prisma,
        ),
        patch(
            "scripts.backfill_understanding_prompts.generate_understanding_prompts",
            new_callable=AsyncMock,
        ) as mock_generate,
    ):
        summary = await backfill_understanding_prompts(batch_size=10)
    assert summary == {
        "scanned": 1,
        "candidates": 1,
        "eligible": 0,
        "updated": 0,
        "failed": 0,
        "skipped_existing": 0,
        "skipped_no_context": 1,
    }
    mock_generate.assert_not_awaited()

View File

@@ -7,11 +7,8 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { SpinnerGapIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { useEffect, useState } from "react";
import {
getGreetingName,
getInputPlaceholder,
getQuickActions,
} from "./helpers";
import { getGreetingName, getInputPlaceholder } from "./helpers";
import { useQuickActions } from "./useQuickActions";
interface Props {
inputLayoutId: string;
@@ -33,7 +30,7 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
const quickActions = useQuickActions(user);
const [loadingAction, setLoadingAction] = useState<string | null>(null);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),

View File

@@ -1,5 +1,11 @@
import { User } from "@supabase/supabase-js";
// Fallback quick actions shown when the user has no server-generated prompts.
export const DEFAULT_QUICK_ACTIONS = [
  "I don't know where to start, just ask me stuff",
  "I do the same thing every week and it's killing me",
  "Help me find where I'm wasting my time",
];
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
@@ -12,12 +18,15 @@ export function getInputPlaceholder(width?: number) {
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}
export function getQuickActions() {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
export function getQuickActions(prompts?: string[] | null) {
  // Trim the incoming prompts and drop empties; fall back to the defaults
  // when nothing usable remains (or no prompts were provided at all).
  const cleaned: string[] = [];
  for (const prompt of prompts ?? []) {
    const trimmed = prompt.trim();
    if (trimmed.length > 0) cleaned.push(trimmed);
  }
  return cleaned.length > 0 ? cleaned : DEFAULT_QUICK_ACTIONS;
}
export function getGreetingName(user?: User | null) {

View File

@@ -0,0 +1,61 @@
import { renderHook } from "@testing-library/react";
import { describe, expect, it, vi } from "vitest";
import type { User } from "@supabase/supabase-js";
import { DEFAULT_QUICK_ACTIONS } from "./helpers";
import { useQuickActions } from "./useQuickActions";
// vi.hoisted ensures the mock fn exists before the hoisted vi.mock factory runs.
const { mockUseGetV1GetBusinessUnderstandingPrompts } = vi.hoisted(() => ({
  mockUseGetV1GetBusinessUnderstandingPrompts: vi.fn(),
}));

// Replace the generated API hook module with the controllable mock.
vi.mock("@/app/api/__generated__/endpoints/auth/auth", () => ({
  useGetV1GetBusinessUnderstandingPrompts:
    mockUseGetV1GetBusinessUnderstandingPrompts,
}));

// Minimal authenticated-user stub — only the id is read by the hook path.
function makeUser() {
  return { id: "user-1" } as User;
}
describe("useQuickActions", () => {
  it("uses server prompts when available", () => {
    mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
      data: ["Help me automate onboarding", "Find my biggest bottleneck"],
    });
    const { result } = renderHook(() => useQuickActions(makeUser()));
    expect(result.current).toEqual([
      "Help me automate onboarding",
      "Find my biggest bottleneck",
    ]);
    // Authenticated user -> the query must be enabled.
    expect(mockUseGetV1GetBusinessUnderstandingPrompts).toHaveBeenCalledWith({
      query: expect.objectContaining({ enabled: true }),
    });
  });

  it("falls back to defaults when the user is not authenticated", () => {
    mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
      data: undefined,
    });
    const { result } = renderHook(() => useQuickActions(null));
    expect(result.current).toEqual(DEFAULT_QUICK_ACTIONS);
    // No user -> the query stays disabled (no request is issued).
    expect(mockUseGetV1GetBusinessUnderstandingPrompts).toHaveBeenCalledWith({
      query: expect.objectContaining({ enabled: false }),
    });
  });

  it("falls back to defaults when the API returns no prompts", () => {
    // Error state plus empty data must still yield the default actions.
    mockUseGetV1GetBusinessUnderstandingPrompts.mockReturnValue({
      data: [],
      error: new Error("no prompts"),
      isError: true,
    });
    const { result } = renderHook(() => useQuickActions(makeUser()));
    expect(result.current).toEqual(DEFAULT_QUICK_ACTIONS);
  });
});

View File

@@ -0,0 +1,17 @@
"use client";
import { useGetV1GetBusinessUnderstandingPrompts } from "@/app/api/__generated__/endpoints/auth/auth";
import { okData } from "@/app/api/helpers";
import { User } from "@supabase/supabase-js";
import { getQuickActions } from "./helpers";
export function useQuickActions(user?: User | null) {
  // Fetch server-generated prompts only for authenticated users; `select`
  // unwraps the response envelope to just the prompts array (or undefined).
  const quickPrompts = useGetV1GetBusinessUnderstandingPrompts({
    query: {
      enabled: Boolean(user),
      select: (response) => okData(response)?.prompts,
    },
  }).data;
  // getQuickActions falls back to the default actions when no usable prompts exist.
  return getQuickActions(quickPrompts);
}

File diff suppressed because it is too large Load Diff