Compare commits

67 Commits

Author SHA1 Message Date
Bentlybro
be328c1ec5 feat(backend/llm-registry): wire refresh_runtime_caches to Redis invalidation and pub/sub
After any admin DB mutation, clear the shared Redis cache, refresh this
process's in-memory state, then publish a notification so all other workers
reload from Redis without hitting the database.
2026-04-07 18:35:41 +01:00
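
The ordering described in this commit (clear the shared cache, refresh the local process, then notify the other workers) can be sketched roughly as below. This is only an illustration under assumptions: `FakeSharedCache`, `FakeBus`, `Worker`, `DB_ROWS`, and `CACHE_KEY` are hypothetical in-process stand-ins for Redis and the database, and only the channel name `llm_registry:refresh` comes from this log.

```python
from typing import Callable

CACHE_KEY = "llm_registry"          # hypothetical cache key
CHANNEL = "llm_registry:refresh"    # channel name mentioned in this log
DB_ROWS = ["model-a"]               # stand-in for the database table


class FakeSharedCache:
    """Stand-in for the shared Redis cache."""

    def __init__(self) -> None:
        self.store: dict[str, list[str]] = {}


class FakeBus:
    """Stand-in for Redis pub/sub: delivers messages synchronously."""

    def __init__(self) -> None:
        self.subscribers: list[Callable[[str], None]] = []

    def publish(self, channel: str) -> None:
        for callback in self.subscribers:
            callback(channel)


class Worker:
    def __init__(self, cache: FakeSharedCache, bus: FakeBus) -> None:
        self.cache = cache
        self.models: list[str] = []
        self.db_reads = 0
        bus.subscribers.append(self.on_notification)

    def load(self) -> None:
        # Read-through: only touch the "database" when the shared cache is cold.
        if CACHE_KEY not in self.cache.store:
            self.db_reads += 1
            self.cache.store[CACHE_KEY] = list(DB_ROWS)
        self.models = list(self.cache.store[CACHE_KEY])

    def on_notification(self, channel: str) -> None:
        if channel == CHANNEL:
            self.load()


def refresh_runtime_caches(worker: Worker, cache: FakeSharedCache, bus: FakeBus) -> None:
    cache.store.pop(CACHE_KEY, None)  # 1. invalidate the shared entry
    worker.load()                     # 2. refresh this process (one DB read)
    bus.publish(CHANNEL)              # 3. others reload from the warm cache
```

In this sketch only the worker that performed the mutation reads the database; every other worker reloads from the repopulated shared entry.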
Bentlybro
8410448c16 fix(backend/llm-registry): enforce single recommended model in update_model
When setting is_recommended=True on a model, first clears the flag on all
other models within the same transaction so only one model can be
recommended at a time.
2026-04-07 18:35:08 +01:00
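
A minimal sketch of the "clear, then set, in one transaction" idea, using the standard-library `sqlite3` module rather than the project's Prisma client. The table name `llm_model` and the helper `set_recommended` are assumptions made for illustration only.

```python
import sqlite3


def set_recommended(conn: sqlite3.Connection, slug: str) -> None:
    # `with conn` commits both statements together, or rolls both back on error.
    with conn:
        conn.execute(
            "UPDATE llm_model SET is_recommended = 0 WHERE slug != ?", (slug,)
        )
        cur = conn.execute(
            "UPDATE llm_model SET is_recommended = 1 WHERE slug = ?", (slug,)
        )
        if cur.rowcount == 0:
            # Unknown slug: raising here undoes the clearing step as well.
            raise LookupError(f"unknown model: {slug}")


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE llm_model ("
    "slug TEXT PRIMARY KEY, is_recommended INTEGER NOT NULL DEFAULT 0)"
)
conn.executemany(
    "INSERT INTO llm_model VALUES (?, ?)", [("model-a", 1), ("model-b", 0)]
)
conn.commit()

set_recommended(conn, "model-b")
```

Doing both updates in one transaction means a failed update cannot leave the table with no recommended model.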
Bentlybro
e168597663 chore: regenerate OpenAPI schema from current backend endpoints 2026-04-07 18:35:08 +01:00
Bentlybro
1d903ae287 fix: add trailing newline to openapi.json 2026-04-07 18:35:08 +01:00
Bentlybro
1be7aebdea chore: regenerate OpenAPI schema for new migration endpoints 2026-04-07 18:35:08 +01:00
Bentlybro
36045c7007 feat(backend): Add model migration system - usage tracking, safe delete, disable with migration, revert
- GET /llm/models/{slug}/usage - count AgentNodes using a model
- DELETE /llm/models/{slug} with optional replacement_model_slug for safe migration
- POST /llm/models/{slug}/toggle with migration support when disabling
- GET /llm/migrations - list model migrations (with include_reverted filter)
- POST /llm/migrations/{id}/revert - revert a migration (restores nodes, re-enables source model)
- Transactional migration: counts nodes, migrates atomically, creates LlmModelMigration audit record
- Ported from original PR #11699's db.py
2026-04-07 18:35:08 +01:00
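
The migrate-and-revert flow listed above might look roughly like the following sketch. It uses `sqlite3` instead of Prisma, and the table names (`agent_node`, `model_migration`), the columns, and the function names are all hypothetical. Re-enabling the source model on revert is left out for brevity.

```python
import sqlite3

SCHEMA = """
CREATE TABLE agent_node (id INTEGER PRIMARY KEY, model_slug TEXT NOT NULL);
CREATE TABLE model_migration (
    id INTEGER PRIMARY KEY,
    source_slug TEXT NOT NULL,
    target_slug TEXT NOT NULL,
    node_ids TEXT NOT NULL,              -- ids of moved nodes, kept for revert
    reverted INTEGER NOT NULL DEFAULT 0
);
"""


def migrate(conn: sqlite3.Connection, source: str, target: str) -> int:
    """Move every node from `source` to `target` and write an audit row."""
    with conn:  # the node update and the audit insert commit together
        ids = [
            row[0]
            for row in conn.execute(
                "SELECT id FROM agent_node WHERE model_slug = ?", (source,)
            )
        ]
        conn.execute(
            "UPDATE agent_node SET model_slug = ? WHERE model_slug = ?",
            (target, source),
        )
        cur = conn.execute(
            "INSERT INTO model_migration (source_slug, target_slug, node_ids) "
            "VALUES (?, ?, ?)",
            (source, target, ",".join(map(str, ids))),
        )
        return cur.lastrowid


def revert(conn: sqlite3.Connection, migration_id: int) -> None:
    """Restore exactly the nodes recorded in the audit row."""
    with conn:
        row = conn.execute(
            "SELECT source_slug, node_ids FROM model_migration "
            "WHERE id = ? AND reverted = 0",
            (migration_id,),
        ).fetchone()
        if row is None:
            raise LookupError(f"no active migration {migration_id}")
        source, node_ids = row
        for node_id in filter(None, node_ids.split(",")):
            conn.execute(
                "UPDATE agent_node SET model_slug = ? WHERE id = ?",
                (source, int(node_id)),
            )
        conn.execute(
            "UPDATE model_migration SET reverted = 1 WHERE id = ?",
            (migration_id,),
        )
```

Recording the affected node ids lets a revert restore only the nodes that were actually moved, not nodes that already used the target model.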
Bentlybro
445eb173a5 feat(backend): Add admin list endpoints, creator CRUD, model cost creation
- GET /llm/admin/providers - list all providers from DB (includes empty ones)
- GET /llm/admin/models - list all models with costs and creator info
- POST /llm/creators - create new creator
- PATCH /llm/creators/{name} - update creator
- DELETE /llm/creators/{name} - delete creator (with model check)
- Create LlmModelCost records when creating a model
- Resolve provider name to ID in create_model
- Add costs field to CreateLlmModelRequest
2026-04-07 18:35:08 +01:00
Bentlybro
393a138fee fix(backend): Use {slug:path} for model routes to support slugs with slashes
Model slugs like 'openai/gpt-oss-120b' contain forward slashes which
FastAPI's default {slug} parameter doesn't capture. Using {slug:path}
allows the full slug to be captured as a single parameter.
2026-04-07 18:35:08 +01:00
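
The difference between `{slug}` and `{slug:path}` can be modelled with two regular expressions. This is a simplified sketch of route matching, not FastAPI's actual implementation, and `compile_route` and `match_slug` are hypothetical helpers.

```python
import re
from typing import Optional

SEGMENT = r"[^/]+"  # plain parameter: a single path segment, no slashes
PATH = r".*"        # ":path" parameter: the rest of the path, slashes included


def compile_route(template: str) -> "re.Pattern[str]":
    pattern = template.replace("{slug:path}", f"(?P<slug>{PATH})")
    pattern = pattern.replace("{slug}", f"(?P<slug>{SEGMENT})")
    return re.compile(f"^{pattern}$")


def match_slug(template: str, url: str) -> Optional[str]:
    match = compile_route(template).match(url)
    return match.group("slug") if match else None


# A slug containing a slash only matches the ":path" form.
print(match_slug("/llm/models/{slug}", "/llm/models/openai/gpt-oss-120b"))
print(match_slug("/llm/models/{slug:path}", "/llm/models/openai/gpt-oss-120b"))
```

Because a `:path` parameter also matches trailing segments, routes with a suffix such as `/usage` may need to be registered before the bare catch-all route.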
Bentlybro
ccc1e35c5b Add LLM creators endpoint and OpenAPI entry
Introduce a read endpoint for LLM model creators: add _map_creator_response serializer and an admin-only GET /llm/creators route that queries prisma.models.LlmModelCreator (ordered by name), logs results, and returns serialized creators with error handling. Also update frontend OpenAPI spec with the /api/llm/creators GET operation.
2026-04-07 18:35:07 +01:00
Bentlybro
c66f114e28 Add LLM model/provider API endpoints
Add admin LLM CRUD endpoints and request schemas to the OpenAPI spec. Introduces POST /api/llm/models and DELETE/PATCH /api/llm/models/{slug}, and POST /api/llm/providers and DELETE/PATCH /api/llm/providers/{name} (all with HTTPBearerJWT security where applicable). Adds CreateLlmModelRequest, UpdateLlmModelRequest, CreateLlmProviderRequest, and UpdateLlmProviderRequest component schemas and corresponding responses (201/200/204 plus validation and auth errors). Notes that a provider can be deleted only when it has no associated models.
2026-04-07 18:35:07 +01:00
Bentlybro
939edc73b8 feat(platform): Implement LLM registry admin API functionality
Implement full CRUD operations for admin API:

Database layer (db_write.py):
- create_provider, update_provider, delete_provider
- create_model, update_model, delete_model
- refresh_runtime_caches - invalidates in-memory registry after mutations
- Proper validation and error handling

Admin routes (admin_routes.py):
- All endpoints now functional (no more 501)
- Proper error responses (400 for validation, 404 for not found, 500 for server errors)
- Lookup by slug/name before operations
- Cache refresh after all mutations

Features:
- Provider deletion blocked if models exist (FK constraint)
- All mutations refresh registry cache automatically
- Proper logging for audit trail
- Admin auth enforced on all endpoints

Based on original implementation from PR #11699 (upstream-llm branch).

Builds on:
- PR #12357: Schema foundation
- PR #12359: Registry core
- PR #12371: Public read API
2026-04-07 18:35:07 +01:00
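
The "provider deletion blocked if models exist" rule and the error mapping described above could be sketched as follows. The in-memory `providers` dict, the `ProviderInUseError` class, and both function names are assumptions, as is the choice of 204 for a successful delete.

```python
class ProviderInUseError(ValueError):
    """Raised when a provider still has models attached."""


def delete_provider(providers: dict[str, list[str]], name: str) -> None:
    # `providers` maps a provider name to the slugs of its models.
    if name not in providers:
        raise KeyError(name)
    if providers[name]:
        raise ProviderInUseError(
            f"provider {name!r} still has {len(providers[name])} model(s)"
        )
    del providers[name]


def delete_provider_status(providers: dict[str, list[str]], name: str) -> int:
    """Translate the outcome into an HTTP status code."""
    try:
        delete_provider(providers, name)
    except KeyError:
        return 404  # not found
    except ProviderInUseError:
        return 400  # validation error: models must be removed first
    return 204      # deleted, no content
```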
Bentlybro
d52409c853 feat(platform): Add LLM registry admin API skeleton - Part 4 of 6
Add admin write API endpoints for LLM registry management:
- POST /api/llm/models - Create model
- PATCH /api/llm/models/{slug} - Update model
- DELETE /api/llm/models/{slug} - Delete model
- POST /api/llm/providers - Create provider
- PATCH /api/llm/providers/{name} - Update provider
- DELETE /api/llm/providers/{name} - Delete provider

All endpoints require admin authentication via requires_admin_user.

Request/response models defined in admin_model.py:
- CreateLlmModelRequest, UpdateLlmModelRequest
- CreateLlmProviderRequest, UpdateLlmProviderRequest

Implementation coming in follow-up commits (currently returns 501 Not Implemented).

This builds on:
- PR #12357: Schema foundation
- PR #12359: Registry core
- PR #12371: Public read API
2026-04-07 18:35:07 +01:00
Bentlybro
90a68084eb fix(backend): Include is_enabled in public model list response 2026-04-07 18:34:58 +01:00
Bentlybro
fb9a3224be Add is_enabled field to OpenAPI model
Introduce a new boolean property `is_enabled` (default: true) into the OpenAPI schema in autogpt_platform/frontend/src/app/api/openapi.json next to `price_tier` and `is_recommended`. This exposes an enable/disable flag in the API model for consumers and defaults new entries to enabled.
2026-04-07 18:34:58 +01:00
Bentlybro
eb76b95aa5 Add is_enabled flag to LlmModel
Introduce an is_enabled: bool = True field to the LlmModel pydantic model to allow toggling model availability. Defaulting to True preserves backward compatibility and avoids breaking changes; can be used by APIs or UIs to filter or disable models without removing them.
2026-04-07 18:34:58 +01:00
Bentlybro
cc17884360 Add LLM models/providers endpoints to OpenAPI
Add two new GET endpoints to the OpenAPI spec: /api/llm/models (with optional enabled_only query param, JWT auth) and /api/llm/providers (JWT auth). These endpoints expose the in-memory LLM registry: list of models and grouped providers with their enabled models. Also add related component schemas (LlmModel, LlmModelCost, LlmModelCreator, LlmModelsResponse, LlmProvider, LlmProvidersResponse) describing model metadata, costs, creators and response shapes.
2026-04-07 18:34:58 +01:00
Bentlybro
1ce3cc0231 fix: remove incorrectly placed openapi.json file 2026-04-07 18:34:58 +01:00
Bentlybro
bd1f4b5701 fix: regenerate OpenAPI schema after rebase 2026-04-07 18:34:58 +01:00
Bentlybro
e89e56d90d feat(platform): Add LLM registry public read API
Implements public GET endpoints for querying LLM models and providers - Part 3 of 6 in the incremental registry rollout.

**Endpoints:**
- GET /api/llm/models - List all models (filterable by enabled_only)
- GET /api/llm/providers - List providers with their models

**Design:**
- Uses in-memory registry from PR 2 (no DB queries)
- Fast reads from cache populated at startup
- Grouped by provider for easy UI rendering

**Response models:**
- LlmModel - model info with capabilities, costs, creator
- LlmProvider - provider with nested models
- LlmModelsResponse - list + total count
- LlmProvidersResponse - grouped by provider

**Authentication:**
- Requires user auth (requires_user dependency)
- Public within authenticated sessions

**Integration:**
- Registered in rest_api.py at /api prefix
- Tagged with v2 + llm for OpenAPI grouping

**What's NOT included (later PRs):**
- Admin write API (PR 4)
- Block integration (PR 5)
- Redis cache (PR 6)

Lines: ~180 total
Files: 4 (3 new, 1 modified)
Review time: < 10 minutes
2026-04-07 18:34:58 +01:00
Bentlybro
2a923dcd92 feat(backend/llm-registry): add Redis-backed cache and cross-process pub/sub sync
- Wrap DB fetch with @cached(shared_cache=True) so results are stored in
  Redis automatically — other workers skip the DB on warm cache
- Add notifications.py with publish/subscribe helpers using llm_registry:refresh
  pub/sub channel for cross-process invalidation
- clear_registry_cache() invalidates the shared Redis entry before a forced
  DB refresh (called by admin mutations)
- rest_api.py: start a background subscription task so every worker reloads
  its in-process cache when another worker refreshes the registry
2026-04-07 18:34:45 +01:00
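
The background subscription task mentioned for rest_api.py might be shaped like the sketch below. An `asyncio.Queue` stands in for the Redis pub/sub channel, and `subscribe_to_refresh` and `demo` are hypothetical names; the real helpers in notifications.py are not shown in this log.

```python
import asyncio
from typing import Callable


async def subscribe_to_refresh(
    channel: asyncio.Queue, reload: Callable[[], None]
) -> None:
    """Reload the in-process cache for each message; `None` stops the loop."""
    while True:
        message = await channel.get()
        if message is None:
            break
        reload()


async def demo() -> int:
    reloads = 0

    def reload() -> None:
        nonlocal reloads
        reloads += 1

    channel: asyncio.Queue = asyncio.Queue()
    # Started once at application startup, then left running in the background.
    task = asyncio.create_task(subscribe_to_refresh(channel, reload))
    await channel.put("refresh")
    await channel.put("refresh")
    await channel.put(None)  # shutdown signal for the sketch
    await task
    return reloads


print(asyncio.run(demo()))
```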
Bentlybro
1fffd21b16 fix(registry): switch to Pydantic models, add typed capabilities, add unit tests
- Replace frozen dataclasses with Pydantic BaseModel(frozen=True) for true immutability
- Add typed boolean fields for model capabilities (supports_tools, etc.)
- Add comprehensive unit tests for registry module
- Addresses Majdyz review feedback on PR #12359
2026-04-05 08:04:55 +00:00
Bentlybro
2241a62b75 fix(registry): address Majdyz review - extract helper, fix schema prefix, return copies, remove re-export 2026-04-04 20:38:13 +00:00
Bentlybro
a5b71b9783 style: fix trailing whitespace in registry.py 2026-03-25 14:12:06 +00:00
Bentlybro
7632548408 fix(startup): handle missing AgentNode table in migrate_llm_models
Tests fail with 'relation "platform.AgentNode" does not exist' because
migrate_llm_models() runs during startup and queries a table that doesn't
exist in fresh test databases.

This is an existing bug in the codebase - the function has no error handling.

Wrap the call in try/except to gracefully handle test environments where
the AgentNode table hasn't been created yet.
2026-03-25 14:12:06 +00:00
Bentlybro
05fa10925c refactor: address CodeRabbit/Majdyz review feedback
- Fix ModelMetadata duplicate type collision by importing from blocks.llm
- Remove _json_to_dict helper, use dict() inline
- Add warning when Provider relation is missing (data corruption indicator)
- Optimize get_default_model_slug with next() (single sort pass)
- Optimize _build_schema_options to use list comprehension
- Move llm_registry import to top-level in rest_api.py
- Ensure max_output_tokens falls back to context_window when null

All critical and quick-win issues addressed.
2026-03-25 14:12:06 +00:00
Bentlybro
c64246be87 fix: address Sentry/CodeRabbit critical and major issues
**CRITICAL FIX - ModelMetadata instantiation:**
- Removed non-existent 'supports_vision' argument
- Added required fields: display_name, provider_name, creator_name, price_tier
- Handle nullable DB fields (Creator, priceTier, maxOutputTokens) safely
- Fallback: creator_name='Unknown' if no Creator, price_tier=1 if invalid

**MAJOR FIX - Preserve pricing unit:**
- Added 'unit' field to RegistryModelCost dataclass
- Prevents RUN vs TOKENS ambiguity in cached costs
- Convert Prisma enum to string when building cost objects

**MAJOR FIX - Deterministic default model:**
- Sort recommended models by display_name before selection
- Prevents non-deterministic results when multiple models are recommended
- Ensures consistent default across refreshes

**STARTUP IMPROVEMENT:**
- Added comment: graceful fallback OK for now (no blocks use registry yet)
- Will be stricter in PR #5 when block integration lands
- Added success log message for registry refresh

Fixes identified by Sentry (critical TypeError) and CodeRabbit review.
2026-03-25 14:12:06 +00:00
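
A sketch of the deterministic default selection described above: sort by `display_name`, take the first recommended model, otherwise fall back to the first enabled one. The `Model` dataclass is a simplified stand-in for the registry model, and applying the sort to the fallback as well is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Model:
    slug: str
    display_name: str
    is_enabled: bool = True
    is_recommended: bool = False


def get_default_model_slug(models: list[Model]) -> Optional[str]:
    # Sorting first makes the result independent of the input order.
    enabled = sorted(
        (m for m in models if m.is_enabled), key=lambda m: m.display_name
    )
    recommended = next((m for m in enabled if m.is_recommended), None)
    chosen = recommended if recommended is not None else next(iter(enabled), None)
    return chosen.slug if chosen is not None else None
```

With two recommended models, the one whose display name sorts first is returned regardless of the order in which the database produced the rows.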
Bentlybro
253937e7b9 feat(platform): Add LLM registry core - DB layer + in-memory cache
Implements the registry core for dynamic LLM model management:

**DB Layer:**
- Fetch models with provider, costs, and creator relations
- Prisma query with includes for related data
- Convert DB records to typed dataclasses

**In-memory Cache:**
- Global dict for fast model lookups
- Atomic cache refresh with lock protection
- Schema options generation for UI dropdowns

**Public API:**
- get_model(slug) - lookup by slug
- get_all_models() - all models (including disabled)
- get_enabled_models() - enabled models only
- get_schema_options() - UI dropdown data
- get_default_model_slug() - recommended or first enabled
- refresh_llm_registry() - manual refresh trigger

**Integration:**
- Refresh at API startup (before block init)
- Graceful fallback if registry unavailable
- Enables blocks to consume registry data

**Models:**
- RegistryModel - full model with metadata
- RegistryModelCost - pricing configuration
- RegistryModelCreator - model creator info
- ModelMetadata - context window, capabilities

**Next PRs:**
- PR #3: Public read API (GET endpoints)
- PR #4: Admin write API (POST/PATCH/DELETE)
- PR #5: Block integration (update LLM block)
- PR #6: Redis cache (solve thundering herd)

Lines: ~230 (registry.py ~210, __init__.py ~30, model.py from draft)
Files: 4 (3 new, 1 modified)
2026-03-25 14:12:06 +00:00
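
The "global dict with lock-protected atomic refresh" design can be sketched as below. The function names follow the public API listed in this commit, but the bodies, the plain-dict records, and the `fetch` callback are assumptions. Returning copies mirrors the later review fix in 2241a62b75.

```python
import threading
from typing import Callable, Optional

_lock = threading.Lock()
_models: dict[str, dict] = {}


def refresh_llm_registry(fetch: Callable[[], list[dict]]) -> None:
    """Build the new mapping off to the side, then swap it in under the lock."""
    global _models
    new_models = {row["slug"]: row for row in fetch()}
    with _lock:
        _models = new_models  # readers see the old or the new dict, never a mix


def get_model(slug: str) -> Optional[dict]:
    with _lock:
        model = _models.get(slug)
    # Return a copy so callers cannot mutate the cached record.
    return dict(model) if model is not None else None


def get_enabled_models() -> list[dict]:
    with _lock:
        snapshot = list(_models.values())
    return [dict(m) for m in snapshot if m.get("is_enabled", True)]
```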
Bentlybro
73e481b508 revert: undo changes to graph.py
Reverting migrate_llm_models modifications per request.
Back to dev baseline for this file.
2026-03-25 14:12:06 +00:00
Bentlybro
f0cc4ae573 Seed LLM model creators and link models
Update migration to seed LLM model creators and associate them with models. Adds INSERTs for LlmModelCreator, updates the file header comment, introduces a creator_ids CTE, and extends the LlmModel INSERT to include creatorId (joining on creator name). Existing provider seeding and model cost logic remain unchanged; ON CONFLICT behavior preserved.
2026-03-25 13:57:41 +00:00
Bently
e0282b00db Merge branch 'dev' into feat/llm-registry-schema 2026-03-23 13:43:15 +00:00
Bentlybro
9a9c36b806 Update LLM registry seeds and conflict clause
Add and rename model slugs and costs in the LLM registry seed migration (e.g. rename 'o3' -> 'o3-2025-04-16', add 'gpt-5.2-2025-12-11', Anthropic 'claude-opus-4-6'/'claude-sonnet-4-6', multiple Google Gemini and Mistralai OpenRouter entries, and other provider models). Also tighten the ON CONFLICT upsert semantics so conflicts are ignored only when "credentialId" IS NULL, preventing silent skips for credentialed entries. These changes seed new models and ensure correct conflict handling during migration.
2026-03-23 13:20:48 +00:00
Zamil Majdy
e86ac21c43 feat(platform): add workflow import from other tools (n8n, Make.com, Zapier) (#12440)
## Summary
- Enable one-click import of workflows from other platforms (n8n,
Make.com, Zapier, etc.) into AutoGPT via CoPilot
- **No backend endpoint** — import is entirely client-side: the dialog
reads the file or fetches the n8n template URL, uploads the JSON to the
workspace via `uploadFileDirect`, stores the file reference in
`sessionStorage`, and redirects to CoPilot with `autosubmit=true`
- CoPilot receives the workflow JSON as a proper file attachment and
uses the existing agent-generator pipeline to convert it
- Library dialog redesigned: 2 tabs — "AutoGPT agent" (upload exported
agent JSON) and "Another platform" (file upload + optional n8n URL)

## How it works
1. User uploads a workflow JSON (or pastes an n8n template URL)
2. Frontend fetches/reads the JSON and uploads it to the user's
workspace via the existing file upload API
3. User is redirected to `/copilot?source=import&autosubmit=true`
4. CoPilot picks up the file from `sessionStorage` and sends it as a
`FileUIPart` attachment with a prompt to recreate the workflow as an
AutoGPT agent

## Test plan
- [x] Manual test: import a real n8n workflow JSON via the dialog
- [x] Manual test: paste an n8n template URL and verify it fetches +
converts
- [x] Manual test: import Make.com / Zapier workflow export JSON
- [x] Repeated imports don't cause 409 conflicts (filenames use
`crypto.randomUUID()`)
- [x] E2E: Import dialog has 2 tabs (AutoGPT agent + Another platform)
- [x] E2E: n8n quick-start template buttons present
- [x] E2E: n8n URL input enables Import button on valid URL
- [x] E2E: Workspace upload API returns file_id
2026-03-23 13:03:02 +00:00
Bentlybro
d5381625cd Add timestamps to LLM registry seed inserts
Update migration.sql to include createdAt and updatedAt columns/values for LlmProvider, LlmModel, and LlmModelCost seed inserts. Uses CURRENT_TIMESTAMP for both timestamp fields and adjusts the INSERT SELECT ordering for models to match the added columns. This ensures the seed data satisfies schemas that require timestamps and provides consistent created/updated metadata.
2026-03-23 12:59:30 +00:00
Bentlybro
f6ae3d6593 Update migration.sql 2026-03-23 12:43:06 +00:00
Lluis Agusti
94224be841 Merge remote-tracking branch 'origin/master' into dev 2026-03-23 20:42:32 +08:00
Bentlybro
0fb1b854df add LLMs via migration 2026-03-23 12:37:29 +00:00
Otto
da4bdc7ab9 fix(backend+frontend): reduce Sentry noise from user-caused errors (#12513)
Requested by @majdyz

User-caused errors (no payment method, webhook agent invocation, missing
credentials, bad API keys) were hitting Sentry via `logger.exception()`
in the `ValueError` handler, creating noise that obscures real bugs.
Additionally, a frontend crash on the copilot page (BUILDER-71J) needed
fixing.

**Changes:**

**Backend — rest_api.py**
- Set `log_error=False` for the `ValueError` exception handler (line
278), consistent with how `FolderValidationError` and `NotFoundError`
are already handled. User-caused 400 errors no longer trigger
`logger.exception()` → Sentry.

**Backend — executor/manager.py**
- Downgrade `ExecutionManager` input validation skip errors from `error`
to `warning` level. Missing credentials is expected user behavior, not
an internal error.

**Backend — blocks/llm.py**
- Sanitize unpaired surrogates in LLM prompt content before sending to
provider APIs. Prevents `UnicodeEncodeError: surrogates not allowed`
when httpx encodes the JSON body (AUTOGPT-SERVER-8AX).
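
One possible way to sanitize unpaired surrogates is sketched here; it is not necessarily what `blocks/llm.py` does. The helper name `sanitize_surrogates` and the choice of U+FFFD as the replacement character are assumptions.

```python
import re

# In a Python str, any code point in this range is a lone surrogate,
# which the UTF-8 codec refuses to encode.
_LONE_SURROGATES = re.compile(r"[\ud800-\udfff]")


def sanitize_surrogates(text: str) -> str:
    """Replace lone surrogates with U+FFFD so the text can be UTF-8 encoded."""
    return _LONE_SURROGATES.sub("\ufffd", text)


prompt = "price: \ud83d ok"
safe = sanitize_surrogates(prompt)
safe.encode("utf-8")  # succeeds; prompt.encode("utf-8") would raise
```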

**Frontend — package.json**
- Upgrade `ai` SDK from `6.0.59` to `6.0.134` to fix BUILDER-71J
(`TypeError: undefined is not an object (evaluating
'this.activeResponse.state')` on /copilot page). This is a known issue
in the Vercel AI SDK fixed in later patch versions.

**Sentry issues addressed:**
- `No payment method found` (ValueError → 400)
- `This agent is triggered by an external event (webhook)` (ValueError →
400)
- `Node input updated with non-existent credentials` (ValueError → 400)
- `[ExecutionManager] Skip execution, input validation error: missing
input {credentials}`
- `UnicodeEncodeError: surrogates not allowed` (AUTOGPT-SERVER-8AX)
- `TypeError: activeResponse.state` (BUILDER-71J)

Resolves SECRT-2166

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
2026-03-23 12:22:49 +00:00
Zamil Majdy
7176cecf25 perf(copilot): reduce tool schema token cost by 34% (#12398)
## Summary

Reduce CoPilot per-turn token overhead by systematically trimming tool
descriptions, parameter schemas, and system prompt content. All 35 MCP
tool schemas are passed on every SDK call — this PR reduces their size.

### Strategy

1. **Tool descriptions**: Trimmed verbose multi-sentence explanations to
concise single-sentence summaries while preserving meaning
2. **Parameter schemas**: Shortened parameter descriptions to essential
info, removed some `default` values (handled in code)
3. **System prompt**: Condensed `_SHARED_TOOL_NOTES` and storage
supplement template in `prompting.py`
4. **Cross-tool references**: Removed duplicate workflow hints (e.g.
"call find_block before run_block" appeared in BOTH tools — kept only in
the dependent tool). Critical cross-tool references retained (e.g.
`continue_run_block` in `run_block`, `fix_agent_graph` in
`validate_agent`, `get_doc_page` in `search_docs`, `web_fetch`
preference in `browser_navigate`)

### Token Impact

| Metric | Before | After | Reduction |
|--------|--------|-------|-----------|
| System Prompt | ~865 tokens | ~497 tokens | 43% |
| Tool Schemas | ~9,744 tokens | ~6,470 tokens | 34% |
| **Grand Total** | **~10,609 tokens** | **~6,967 tokens** | **34%** |

Saves **~3,642 tokens per conversation turn**.
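
The figures in the table can be re-derived from the per-row token counts. A small check, where the dictionary keys and the `reduction` helper are just illustrative names:

```python
def reduction(before: int, after: int) -> float:
    """Percentage reduction from `before` to `after`."""
    return (before - after) / before * 100


before = {"system_prompt": 865, "tool_schemas": 9744}
after = {"system_prompt": 497, "tool_schemas": 6470}

total_before = sum(before.values())
total_after = sum(after.values())
saved = total_before - total_after

for key in before:
    print(key, round(reduction(before[key], after[key])))
print("total", total_before, total_after, saved,
      round(reduction(total_before, total_after)))
```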

### Key Decisions

- **Mostly description changes**: Tool logic, parameters, and types
unchanged. However, some schema-level `default` fields were removed
(e.g. `save` in `customize_agent`) — these are machine-readable
metadata, not just prose, and may affect LLM behavior.
- **Quality preserved**: All descriptions still convey what the tool
does and essential usage patterns
- **Cross-references trimmed carefully**: Kept prerequisite hints in the
dependent tool (run_block mentions find_block) but removed the reverse
(find_block no longer mentions run_block). Critical cross-tool guidance
retained where removal would degrade model behavior.
- **`run_time` description fixed**: Added missing supported values
(today, last 30 days, ISO datetime) per review feedback

### Future Optimization

The SDK passes all 35 tools on every call. The MCP protocol's
`list_tools()` handler supports dynamic tool registration — a follow-up
PR could implement lazy tool loading (register core tools + a discovery
meta-tool) to further reduce per-turn token cost.

### Changes

- Trimmed descriptions across 25 tool files
- Condensed `_SHARED_TOOL_NOTES` and `_build_storage_supplement` in
`prompting.py`
- Fixed `run_time` schema description in `agent_output.py`

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All 273 copilot tests pass locally
  - [x] All 35 tools load and produce valid schemas
  - [x] Before/after token dumps compared
  - [x] Formatting passes (`poetry run format`)
  - [x] CI green
2026-03-23 08:27:24 +00:00
Zamil Majdy
f35210761c feat(devops): add /pr-test skill + subscription mode auto-provisioning (#12507)
## Summary
- Adds `/pr-test` skill for automated E2E testing of PRs using docker
compose, agent-browser, and API calls
- Covers full environment setup (copy .env, configure copilot auth,
ARM64 Docker fix)
- Includes browser UI testing, direct API testing, screenshot capture,
and test report generation
- Has `--fix` mode for auto-fixing bugs found during testing (similar to
`/pr-address`)
- **Screenshot uploads use GitHub Git API** (blobs → tree → commit →
ref) — no local git operations, safe for worktrees
- **Subscription mode improvements:**
- Extract subscription auth logic to `sdk/subscription.py` — uses SDK's
bundled CLI binary instead of requiring `npm install -g
@anthropic-ai/claude-code`
- Auto-provision `~/.claude/.credentials.json` from
`CLAUDE_CODE_OAUTH_TOKEN` env var on container startup — no `claude
login` needed in Docker
- Add `scripts/refresh_claude_token.sh` — cross-platform helper
(macOS/Linux/Windows) to extract OAuth tokens from host and update
`backend/.env`

## Test plan
- [x] Validated skill on multiple PRs (#12482, #12483, #12499, #12500,
#12501, #12440, #12472) — all test scenarios passed
- [x] Confirmed screenshot upload via GitHub Git API renders correctly
on all 7 PRs
- [x] Verified subscription mode E2E in Docker:
`refresh_claude_token.sh` → `docker compose up` → copilot chat responds
correctly with no API keys (pure OAuth subscription)
- [x] Verified auto-provisioning of credentials file inside container
from `CLAUDE_CODE_OAUTH_TOKEN` env var
- [x] Confirmed bundled CLI detection
(`claude_agent_sdk._bundled/claude`) works without system-installed
`claude`
- [x] `poetry run pytest backend/copilot/sdk/service_test.py` — 24/24
tests pass
2026-03-23 15:29:00 +07:00
Zamil Majdy
1ebcf85669 fix(platform): resolve 5 production Sentry alerts (#12496)
## Summary

Fixes 5 high-priority Sentry alerts from production:

- **AUTOGPT-SERVER-8AM**: Fix `TypeError: TypedDict does not support
instance and class checks` — `_value_satisfies_type` in `type.py` now
handles TypedDict classes that don't support `isinstance()` checks
- **AUTOGPT-SERVER-8AN**: Fix `ValueError: No payment method found`
triggering Sentry error — catch the expected ValueError in the
auto-top-up endpoint and return HTTP 422 instead
- **BUILDER-7F5**: Fix `Upload failed (409): File already exists` — add
`overwrite` query param to workspace upload endpoint and set it to
`true` from the frontend direct-upload
- **BUILDER-7F0**: Fix `LaTeX-incompatible input` KaTeX warnings
flooding Sentry — set `strict: false` on rehype-katex plugin to suppress
warnings for unrecognized Unicode characters
- **AUTOGPT-SERVER-89N**: Fix `Tool execution with manager failed:
validation error for dict[str,list[any]]` — make RPC return type
validation resilient (log warning instead of crash) and downgrade
SmartDecisionMaker tool execution errors to warnings

## Test plan
- [ ] Verify TypedDict type coercion works for
GithubMultiFileCommitBlock inputs
- [ ] Verify auto-top-up without payment method returns 422, not 500
- [ ] Verify file re-upload in copilot succeeds (overwrites instead of
409)
- [ ] Verify LaTeX rendering with Unicode characters doesn't produce
console warnings
- [ ] Verify SmartDecisionMaker tool execution failures are logged at
warning level
2026-03-23 08:05:08 +00:00
Otto
ab7c38bda7 fix(frontend): detect closed OAuth popup and allow dismissing waiting modal (#12443)
Requested by @kcze

When a user closes the OAuth sign-in popup without completing
authentication, the 'Waiting on sign-in process' modal was stuck open
with no way to dismiss it, forcing a page refresh.

Two bugs caused this:

1. `oauth-popup.ts` had no detection for the popup being closed by the
user. The promise would hang until the 5-minute timeout.

2. The modal's cancel button aborted a disconnected `AbortController`
instead of the actual OAuth flow's abort function, so clicking
cancel/close did nothing.

### Changes

- Add `popup.closed` polling (500ms) in `openOAuthPopup()` that rejects
the promise when the user closes the auth window
- Add reject-on-abort so the cancel button properly terminates the flow
- Replace the disconnected `oAuthPopupController` with a direct
`cancelOAuthFlow()` function that calls the real abort ref
- Handle popup-closed and user-canceled as silent cancellations (no
error toast)

### Testing

Tested manually:
- [x] Start OAuth flow → close popup window → modal dismisses
automatically
- [x] Start OAuth flow → click cancel on modal → popup closes, modal
dismisses
- [x] Complete OAuth flow normally → works as before

Resolves SECRT-2054

---
Co-authored-by: Krzysztof Czerwinski (@kcze)
<krzysztof.czerwinski@agpt.co>

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-20 14:41:09 +00:00
Ubbe
0f67e45d05 hotfix(marketplace): adjust card height overflow (#12497)
## Summary

### Before

<img width="500" height="501" alt="Screenshot 2026-03-20 at 21 50 31"
src="https://github.com/user-attachments/assets/6154cffb-6772-4c3d-a703-527c8ca0daff"
/>

### After

<img width="500" height="581" alt="Screenshot 2026-03-20 at 21 33 12"
src="https://github.com/user-attachments/assets/2f9bd69d-30c5-4d06-ad1e-ed76b184afe5"
/>

### Other minor fixes

- minor spacing adjustments in creator/search pages when empty and
between sections


### Summary

- Increase StoreCard height from 25rem to 26.5rem to prevent content
overflow
- Replace manual tooltip-based title truncation with `OverflowText`
component in StoreCard
- Adjust carousel indicator positioning and hide it on md+ when exactly
3 featured agents are shown

## Test plan
- [x] Verify marketplace cards display without text overflow
- [x] Verify featured section carousel indicators behave correctly
- [x] Check responsive behavior at common breakpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-20 22:03:28 +08:00
Ubbe
b9ce37600e refactor(frontend/marketplace): move download below Add to library with contextual text (#12486)
## Summary

<img width="1487" height="670" alt="Screenshot 2026-03-20 at 00 52 58"
src="https://github.com/user-attachments/assets/f09de2a0-3c5b-4bce-b6f4-8a853f6792cf"
/>


- Move the download button from inline next to "Add to library" to a
separate line below it
- Add contextual text: "Want to use this agent locally? Download here"
- Style the "Download here" as a violet ghost button link with the
download icon

## Test plan
- [ ] Visit a marketplace agent page
- [ ] Verify "Add to library" button renders in its row
- [ ] Verify "Want to use this agent locally? Download here" appears
below it
- [ ] Click "Download here" and confirm the agent downloads correctly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-20 13:13:59 +00:00
Otto
3921deaef1 fix(frontend): truncate marketplace card description to 2 lines (#12494)
Reduces `line-clamp` from 3 to 2 on the marketplace `StoreCard`
description to prevent text from overlapping with the
absolutely-positioned run count and +Add button at the bottom of the
card.

Resolves SECRT-2156.

---
Co-authored-by: Abhimanyu Yadav (@Abhi1992002)
<122007096+Abhi1992002@users.noreply.github.com>
2026-03-20 09:10:21 +00:00
Nicholas Tindle
f01f668674 fix(backend): support Responses API in SmartDecisionMakerBlock (#12489)
## Summary

- Fixes SmartDecisionMakerBlock conversation management to work with
OpenAI's Responses API, which was introduced in #12099 (commit 1240f38)
- The migration to `responses.create` updated the outbound LLM call but
missed the conversation history serialization — the `raw_response` is
now the entire `Response` object (not a `ChatCompletionMessage`), and
tool calls/results use `function_call` / `function_call_output` types
instead of role-based messages
- This caused a 400 error on the second LLM call in agent mode:
`"Invalid value: ''. Supported values are: 'assistant', 'system',
'developer', and 'user'."`

### Changes

**`smart_decision_maker.py`** — 6 functions updated:
| Function | Fix |
|---|---|
| `_convert_raw_response_to_dict` | Detects Responses API `Response`
objects, extracts output items as a list |
| `_get_tool_requests` | Recognizes `type: "function_call"` items |
| `_get_tool_responses` | Recognizes `type: "function_call_output"`
items |
| `_create_tool_response` | New `responses_api` kwarg produces
`function_call_output` format |
| `_update_conversation` | Handles list return from
`_convert_raw_response_to_dict` |
| Non-agent mode path | Same list handling for traditional execution |
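
The item shapes named in the table can be illustrated with plain dictionaries. This is a sketch only: the function names echo the table but drop the leading underscore, the bodies are assumptions, and the items are reduced to the fields needed here.

```python
import json
from typing import Any


def get_tool_requests(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Pick out the tool calls from a list of Responses API output items."""
    return [item for item in items if item.get("type") == "function_call"]


def create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
    """Build the item that reports a tool result back to the model."""
    text = output if isinstance(output, str) else json.dumps(output)
    return {"type": "function_call_output", "call_id": call_id, "output": text}


output_items = [
    {"type": "message", "role": "assistant", "content": []},
    {"type": "function_call", "call_id": "call_1", "name": "search",
     "arguments": "{\"q\": \"cats\"}"},
]
calls = get_tool_requests(output_items)
# Replay the model's items, then append one result item per tool call.
conversation = output_items + [
    create_tool_response(c["call_id"], {"hits": 3}) for c in calls
]
```

Unlike role-based chat messages, the tool result item is linked to its request through `call_id` and carries no `role` field.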

**`test_smart_decision_maker_responses_api.py`** — 61 tests covering:
- Every branch of all 6 affected helper functions
- Chat Completions, Anthropic, and Responses API formats
- End-to-end agent mode and traditional mode conversation validity

## Test plan

- [x] 61 new unit tests all pass
- [x] 11 existing SmartDecisionMakerBlock tests still pass (no
regressions)
- [x] All pre-commit hooks pass (ruff, black, isort, pyright)
- [ ] CI integration tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Updates core LLM invocation and agent conversation/tool-call
bookkeeping to match OpenAI’s Responses API, which can affect tool
execution loops and prompt serialization across providers. Risk is
mitigated by extensive new unit tests, but regressions could surface in
production agent-mode flows or token/usage accounting.
> 
> **Overview**
> **Migrates OpenAI calls from Chat Completions to the Responses API
end-to-end**, including tool schema conversion, output parsing,
reasoning/text extraction, and updated token usage fields in
`LLMResponse`.
> 
> **Fixes SmartDecisionMakerBlock conversation/tool handling for
Responses API** by treating `raw_response` as a Response object
(splitting it into `output` items for replay), recognizing
`function_call`/`function_call_output` entries, and emitting tool
outputs in the correct Responses format to prevent invalid follow-up
prompts.
> 
> Also adjusts prompt compaction/token estimation to understand
Responses API tool items, changes
`get_execution_outputs_by_node_exec_id` to return list-valued
`CompletedBlockOutput`, removes `gpt-3.5-turbo` from model/cost/docs
lists, and adds focused unit tests plus a lightweight `conftest.py` to
run these tests without the full server stack.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
ff292efd3d. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Otto <otto@agpt.co>
Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
2026-03-20 03:23:52 +00:00
Otto
f7a3491f91 docs(platform): add TDD guidance to CLAUDE.md files (#12491)
Requested by @majdyz

Adds TDD (test-driven development) guidance to CLAUDE.md files so Claude
Code follows a test-first workflow when fixing bugs or adding features.

**Changes:**
- **Parent `CLAUDE.md`**: Cross-cutting TDD workflow — write a failing
`xfail` test, implement the fix, remove the marker
- **Backend `CLAUDE.md`**: Concrete pytest example with
`@pytest.mark.xfail` pattern
- **Frontend `CLAUDE.md`**: Note about using Playwright `.fixme`
annotation for bug-fix tests

The workflow is: write a failing test first → confirm it fails for the
right reason → implement → confirm it passes. This ensures every bug fix
is covered by a test that would have caught the regression.
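
As a rough sketch of the first step of that workflow (this is not the example from the backend `CLAUDE.md`, and `normalize_slug` is a hypothetical function invented for illustration), a bug-reproducing test can be marked with `@pytest.mark.xfail` before the fix exists:

```python
import pytest


def normalize_slug(raw: str) -> str:
    """Hypothetical function under test; this version still contains the bug."""
    return raw.lower()  # bug: surrounding whitespace is not stripped


# Step 1: the test documents the bug and is expected to fail for that reason.
# strict=True turns an unexpected pass into a failure, which is the reminder
# to remove the marker once the fix has landed.
@pytest.mark.xfail(reason="normalize_slug does not strip whitespace yet", strict=True)
def test_normalize_slug_strips_whitespace():
    assert normalize_slug("  GPT-4o ") == "gpt-4o"
```

After the fix (for example, returning `raw.strip().lower()`), the marker is removed and the test stays in the suite as a regression guard.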

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
2026-03-20 02:13:16 +00:00
Nicholas Tindle
cbff3b53d3 Revert "feat(backend): migrate OpenAI provider to Responses API" (#12490)
Reverts Significant-Gravitas/AutoGPT#12099

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Reverts the OpenAI integration in `llm_call` from the Responses API
back to `chat.completions`, which can change tool-calling, JSON-mode
behavior, and token accounting across core AI blocks. The change is
localized but touches the primary LLM execution path and associated
tests/docs.
> 
> **Overview**
> Reverts the OpenAI path in `backend/blocks/llm.py` from the Responses
API back to `chat.completions`, including updating JSON-mode
(`response_format`), tool handling, and usage extraction to match the
Chat Completions response shape.
> 
> Removes the now-unused `backend/util/openai_responses.py` helpers and
their unit tests, updates LLM tests to mock `chat.completions.create`,
and adds `gpt-3.5-turbo` to the supported model list, cost config, and
LLM docs.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
7d6226d10e. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-03-20 01:51:56 +00:00
Reinier van der Leer
5b9a4c52c9 revert(platform): Revert invite system (#12485)
## Summary

Reverts the invite system PRs due to security gaps identified during
review:

- The move from Supabase-native `allowed_users` gating to
application-level gating allows orphaned Supabase auth accounts (valid
JWT without a platform `User`)
- The auth middleware never verifies `User` existence, so orphaned users
get 500s instead of clean 403s
- OAuth/Google SSO signup completely bypasses the invite gate
- The DB trigger that atomically created `User` + `Profile` on signup
was dropped in favor of a client-initiated API call, introducing a
failure window

### Reverted PRs
- Reverts #12347 — Foundation: InvitedUser model, invite-gated signup,
admin UI
- Reverts #12374 — Tally enrichment: personalized prompts from form
submissions
- Reverts #12451 — Pre-check: POST /auth/check-invite endpoint
- Reverts #12452 (collateral) — Themed prompt categories /
SuggestionThemes UI. This PR built on top of #12374's
`suggested_prompts` backend field and `/chat/suggested-prompts`
endpoint, so it cannot remain without #12374. The copilot empty session
falls back to hardcoded default prompts.

### Migration
Includes a new migration (`20260319120000_revert_invite_system`) that:
- Drops the `InvitedUser` table and its enums (`InvitedUserStatus`,
`TallyComputationStatus`)
- Restores the `add_user_and_profile_to_platform()` trigger on
`auth.users`
- Backfills `User` + `Profile` rows for any auth accounts created during
the invite-gate window

### What's NOT reverted
- The `generate_username()` function (never dropped, still used by
backfill migration)
- The old `add_user_to_platform()` function (superseded by
`add_user_and_profile_to_platform()`)
- PR #12471 (admin UX improvements) — was never merged, no action needed

## Test plan
- [x] Verify migration: `InvitedUser` table dropped, enums dropped,
trigger restored
- [x] Verify backfill: no orphaned auth users, no users without Profile
- [x] Verify existing users can still log in (email + OAuth)
- [x] Verify CoPilot chat page loads with default prompts
- [ ] Verify new user signup creates `User` + `Profile` via the restored
trigger
- [ ] Verify admin `/admin/users` page loads without crashing
- [ ] Run backend tests: `poetry run test`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-19 17:15:30 +00:00
Otto
0ce1c90b55 fix(frontend): rename "CoPilot" to "AutoPilot" on credits page (#12481)
Requested by @kcze

Renames "CoPilot" → "AutoPilot" on the credits/usage limits page:

- **Heading:** "CoPilot Usage Limits" → "AutoPilot Usage Limits"
- **Button:** "Open CoPilot" → "Open AutoPilot"
- Comment updated to match

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>

2026-03-19 15:25:21 +00:00
Ubbe
d4c6eb9adc fix(frontend): collapse navbar text to icons below 1280px (#12484)
## Summary

<img width="400" height="339" alt="Screenshot 2026-03-19 at 22 53 23"
src="https://github.com/user-attachments/assets/2fa76b8f-424d-4764-90ac-b7a331f5f610"
/>

<img width="600" height="595" alt="Screenshot 2026-03-19 at 22 53 31"
src="https://github.com/user-attachments/assets/23f51cc7-b01e-4d83-97ba-2c43683877db"
/>

<img width="800" height="523" alt="Screenshot 2026-03-19 at 22 53 36"
src="https://github.com/user-attachments/assets/1e447b9a-1cca-428c-bccd-1730f1670b8e"
/>

Now that we have the `Give feedback` button on the navigation bar,
collapse some of the links below `1280px` so there is more space and
they don't collide with each other.

- Collapse navbar link text to icon-only below 1280px (`xl` breakpoint)
to prevent crowding
- Wallet button shows only the wallet icon below 1280px instead of "Earn
credits" text
- Feedback button shows only the chat icon below 1280px instead of "Give
Feedback" text
- Added `whitespace-nowrap` to feedback button to prevent wrapping

## Changes
- `NavbarLink.tsx`: `lg:block` → `xl:block` for link text
- `Wallet.tsx`: `md:hidden`/`md:inline-block` →
`xl:hidden`/`xl:inline-block`
- `FeedbackButton.tsx`: wrap text in `hidden xl:inline` span, add
`whitespace-nowrap`

## Test plan
- [ ] Resize browser between 1024px–1280px and verify navbar shows only
icons
- [ ] At 1280px+ verify full text labels appear for links, wallet, and
feedback
- [ ] Verify mobile navbar still works correctly below `md` breakpoint

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 15:10:27 +00:00
Ubbe
1bb91b53b7 fix(frontend/marketplace): comprehensive marketplace UI redesign (#12462)
## Summary

<img width="600" height="964" alt="Screenshot_2026-03-19_at_00 07 52"
src="https://github.com/user-attachments/assets/95c0430a-26a3-499b-8f6a-25b9715d3012"
/>
<img width="600" height="968" alt="Screenshot_2026-03-19_at_00 08 01"
src="https://github.com/user-attachments/assets/d440c3b0-c247-4f13-bf82-a51ff2e50902"
/>
<img width="600" height="939" alt="Screenshot_2026-03-19_at_00 08 14"
src="https://github.com/user-attachments/assets/f19be759-e102-4a95-9474-64f18bce60cf"
/>
<img width="600" height="953" alt="Screenshot_2026-03-19_at_00 08 24"
src="https://github.com/user-attachments/assets/ba4fa644-3958-45e2-89e9-a6a4448c63c5"
/>



- Re-style and re-skin the Marketplace pages to look more "professional"
...
- Move the `Give feedback` button to the header

## Test plan
- [x] Verify marketplace page search bar matches Form text field styling
- [x] Verify agent cards have padding and subtle border
- [x] Verify hover/focus states work correctly
- [x] Check responsive behavior at different breakpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 22:28:01 +08:00
Bentlybro
64a011664a fix(schema): address Majdyz review feedback
- Add FK constraints on LlmModelMigration (sourceModelSlug, targetModelSlug → LlmModel.slug)
- Remove unused @@index([credentialProvider]) on LlmModelCost
- Remove redundant @@index([isReverted]) on LlmModelMigration (covered by composite)
- Add documentation for credentialProvider field explaining its purpose
- Add reverse relation fields to LlmModel (SourceMigrations, TargetMigrations)

Fixes data integrity: typos in migration slugs now caught at DB level.
2026-03-19 11:01:09 +00:00
Bentlybro
1db7c048d9 fix: isort import order 2026-03-19 11:01:09 +00:00
Bentlybro
4c5627c966 fix: use execute_raw_with_schema for proper multi-schema support
Per Sentry feedback: db.execute_raw ignores connection string's ?schema=
parameter and defaults to 'public' schema. This breaks in multi-schema setups.

Changes:
- Import execute_raw_with_schema from .db
- Use {schema_prefix} placeholder in query
- Call execute_raw_with_schema instead of db.execute_raw

This matches the pattern used in fix_llm_provider_credentials and other
schema-aware migrations. Works in both CI (public schema) and local
(platform schema from connection string).
2026-03-19 11:01:09 +00:00
Bentlybro
d97d137a51 fix: remove hardcoded schema prefix from migrate_llm_models query
The raw SQL query in migrate_llm_models() hardcoded platform."AgentNode"
which fails in CI where tables are in 'public' schema (not 'platform').

This code exists in dev but only runs when LLM registry has data. With our
new schema, the migration tries to run at startup and fails in CI.

Changed: UPDATE platform."AgentNode" -> UPDATE "AgentNode"

Matches pattern of all other migrations - let connection string's default
schema handle routing.
2026-03-19 11:01:09 +00:00
Bentlybro
ded9e293ff fix: remove CREATE SCHEMA to match CI environment
CI uses schema "public" as default (not "platform"), so creating
a platform schema then tables without prefix puts tables in public
but Prisma looks in platform.

Existing migrations don't create schema - they rely on connection
string's default. Remove CREATE SCHEMA IF NOT EXISTS to match.
2026-03-19 11:01:09 +00:00
Bentlybro
83d504bed2 fix: remove schema prefix from migration SQL to match existing pattern
CI failing with 'relation "platform.AgentNode" does not exist' because
Prisma generates queries differently when tables are created with
explicit schema prefixes.

Existing AutoGPT migrations use:
  CREATE TABLE "AgentNode" (...)

Not:
  CREATE TABLE "platform"."AgentNode" (...)

The connection string's ?schema=platform handles schema selection,
so explicit prefixes aren't needed and cause compatibility issues.

Changes:
- Remove all "platform". prefixes from:
  * CREATE TYPE statements
  * CREATE TABLE statements
  * CREATE INDEX statements
  * ALTER TABLE statements
  * REFERENCES clauses in foreign keys

Now matches existing migration pattern exactly.
2026-03-19 11:01:09 +00:00
Bentlybro
a5f1ffb35b fix: add partial unique indexes for data integrity
Per CodeRabbit feedback - fix 2 actual bugs:

1. Prevent multiple active migrations per source model
   - Add partial unique index: UNIQUE (sourceModelSlug) WHERE isReverted = false
   - Prevents ambiguous routing when resolving migrations

2. Allow both default and credential-specific costs
   - Remove @@unique([llmModelId, credentialProvider, unit])
   - Add 2 partial unique indexes:
     * UNIQUE (llmModelId, provider, unit) WHERE credentialId IS NULL (defaults)
     * UNIQUE (llmModelId, provider, credentialId, unit) WHERE credentialId IS NOT NULL (overrides)
   - Enables provider-level default costs + per-credential overrides

Schema comments document that these constraints exist in migration SQL.
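
The first constraint can be demonstrated with a small self-contained sketch. It uses an in-memory SQLite database as a stand-in for the project's PostgreSQL database (both support partial unique indexes), and the index name and trimmed-down table are assumptions made for illustration, not the actual migration SQL.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "LlmModelMigration" ('
    " id INTEGER PRIMARY KEY,"
    ' "sourceModelSlug" TEXT NOT NULL,'
    ' "isReverted" INTEGER NOT NULL DEFAULT 0)'
)
# Uniqueness is enforced only for rows matching the WHERE clause,
# i.e. migrations that have not been reverted.
conn.execute(
    "CREATE UNIQUE INDEX one_active_migration_per_source"
    ' ON "LlmModelMigration" ("sourceModelSlug") WHERE "isReverted" = 0'
)


def add_migration(slug: str, reverted: bool = False) -> bool:
    """Insert a row; return False if the partial unique index rejects it."""
    try:
        conn.execute(
            'INSERT INTO "LlmModelMigration" ("sourceModelSlug", "isReverted")'
            " VALUES (?, ?)",
            (slug, int(reverted)),
        )
        return True
    except sqlite3.IntegrityError:
        return False


first_active = add_migration("old-model")         # accepted
second_active = add_migration("old-model")        # rejected: one active row already exists
reverted_copy = add_migration("old-model", True)  # accepted: reverted rows are exempt
```

A plain `UNIQUE (sourceModelSlug)` would also reject the reverted row, which is why the partial form is needed to keep a history of reverted migrations.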
2026-03-19 11:01:09 +00:00
Bentlybro
97c6516a14 fix: remove multiSchema - follow existing AutoGPT pattern
Remove unnecessary multiSchema configuration that broke existing models.

AutoGPT uses connection string's ?schema=platform parameter as default,
not Prisma's multiSchema feature. Existing models (User, AgentGraph, etc.)
have no @@schema() directives and work fine.

Changes:
- Remove schemas = ["platform", "public"] from datasource
- Remove "multiSchema" from previewFeatures
- Remove all @@schema() directives from LLM models and enum

Migration SQL already creates tables in platform schema explicitly
(CREATE TABLE "platform"."LlmProvider" etc.) which is correct.

This matches the existing pattern used throughout the codebase.
2026-03-19 11:01:09 +00:00
Bentlybro
876dde8bc7 fix: address CodeRabbit design feedback
Per CodeRabbit review:

1. **Safety: Change capability defaults to false (safer for partial seeding)**
   - supportsTools: true → false
   - supportsJsonOutput: true → false
   - Prevents partially-seeded rows from being assumed capable

2. **Clarity: Rename supportsParallelTool → supportsParallelToolCalls**
   - More explicit about what the field represents

3. **Performance: Remove redundant indexes**
   - Drop @@index([llmModelId]) - covered by unique constraint
   - Drop @@index([sourceModelSlug]) - covered by composite index
   - Reduces write overhead and storage

4. **Documentation: Acknowledge customCreditCost limitation**
   - It's unit-agnostic (doesn't distinguish RUN vs TOKENS)
   - Noted as TODO for follow-up PR with proper unit-aware override

Schema + migration both updated to match.
2026-03-19 11:01:09 +00:00
Bentlybro
0bfdd74b25 fix: add @@schema("platform") to LlmCostUnit enum
Sentry caught this - enums also need @@schema directive with multiSchema enabled.
Without it, Prisma looks for enum in public schema but it's created in platform.
2026-03-19 11:01:09 +00:00
Bentlybro
a7d2f81b18 feat: add database CHECK constraints for data integrity
Per CodeRabbit feedback - enforce numeric domain rules at DB level:

Migration:
- priceTier: CHECK (priceTier BETWEEN 1 AND 3)
- creditCost: CHECK (creditCost >= 0)
- nodeCount: CHECK (nodeCount >= 0)
- customCreditCost: CHECK (customCreditCost IS NULL OR customCreditCost >= 0)

Schema comments:
- Document constraints inline for developer visibility

Prevents invalid data (negative costs, out-of-range tiers) from
entering the database, matching backend/blocks/llm.py contract.
2026-03-19 11:01:09 +00:00
Bentlybro
3699eaa556 fix: use @@schema() instead of @@map() for platform schema + create schema in migration
Critical fixes from PR review:

1. Replace @@map("platform.ModelName") with @@schema("platform")
   - Sentry correctly identified: Prisma was looking for literal table "platform.LlmProvider" with dot
   - Proper syntax: enable multiSchema feature + use @@schema directive

2. Create platform schema in migration
   - CI failed: schema "platform" does not exist
   - Add CREATE SCHEMA IF NOT EXISTS at start of migration

Schema changes:
- datasource: add schemas = ["platform", "public"]
- generator: add "multiSchema" to previewFeatures
- All 5 models: @@map() → @@schema("platform")

Migration changes:
- Add CREATE SCHEMA IF NOT EXISTS "platform" before enum creation

Fixes CI failure and Sentry-identified bug.
2026-03-19 11:01:09 +00:00
Bentlybro
21adf9e0fb feat(platform): Add LLM registry database schema
Add Prisma schema and migration for dynamic LLM model registry:

Schema additions:
- LlmProvider: Registry of LLM providers (OpenAI, Anthropic, etc.)
- LlmModel: Individual models with capabilities and metadata
- LlmModelCost: Per-model pricing configuration
- LlmModelCreator: Model creators/trainers (OpenAI, Meta, etc.)
- LlmModelMigration: Track model migrations and reverts
- LlmCostUnit enum: RUN vs TOKENS pricing units

Key features:
- Model-specific capabilities (tools, JSON, reasoning, parallel calls)
- Flexible creator/provider separation (e.g., Meta model via Hugging Face)
- Migration tracking with custom pricing overrides
- Indexes for performance on common queries

Part 1 of incremental LLM registry implementation.
Refs: Draft PR #11699
2026-03-19 11:01:08 +00:00
Ubbe
a5f9c43a41 feat(platform): replace suggestion pills with themed prompt categories (#12452)
## Summary



https://github.com/user-attachments/assets/13da6d36-5f35-429b-a6cf-e18316bb8709



Replaces the flat list of suggestion pills in the CoPilot empty session
with themed prompt categories (Learn, Create, Automate, Organize), each
shown as a popover with contextual prompts.

- **Backend**: Changes `suggested_prompts` from a flat `list[str]` to a
themed `dict[str, list[str]]` keyed by category. Updates Tally
extraction LLM prompt to generate prompts per theme, and the
`/suggested-prompts` API to return grouped themes. Legacy `list[str]`
rows are preserved under a `"General"` key for backward compatibility.
- **Frontend**: Replaces inline pill buttons with a `SuggestionThemes`
popover component. Each theme button (with icon) opens a dropdown of 5
relevant prompts. Falls back to hardcoded defaults when the API has no
personalized prompts. Normalizes partial API responses by padding
missing themes with defaults. Legacy `"General"` prompts are distributed
round-robin across themes so existing users keep their personalized
suggestions.
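
The two compatibility rules described above can be sketched as follows. This is an illustrative Python approximation only: the round-robin step lives in the frontend helpers, the function names are hypothetical, and silently dropping non-string items is an assumption about how validation behaves.

```python
from itertools import cycle

THEMES = ["Learn", "Create", "Automate", "Organize"]


def normalize_suggested_prompts(raw: object) -> dict[str, list[str]]:
    """Coerce stored prompts into the themed dict shape."""
    if isinstance(raw, list):
        # Legacy flat list: keep it under "General" (non-string items dropped).
        return {"General": [p for p in raw if isinstance(p, str)]}
    if isinstance(raw, dict):
        return {
            str(theme): [p for p in prompts if isinstance(p, str)]
            for theme, prompts in raw.items()
            if isinstance(prompts, list)
        }
    return {}


def distribute_general(prompts: dict[str, list[str]]) -> dict[str, list[str]]:
    """Spread legacy "General" prompts round-robin across the four themes."""
    themed = {theme: list(prompts.get(theme, [])) for theme in THEMES}
    # zip() stops when the "General" list is exhausted, so cycle() is safe here.
    for theme, prompt in zip(cycle(THEMES), prompts.get("General", [])):
        themed[theme].append(prompt)
    return themed
```

With five legacy prompts, the fifth wraps around to the first theme, so every theme receives at least one personalized suggestion before any receives a second.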

### Changes 🏗️

- `backend/data/understanding.py`: `suggested_prompts` field changed
from `list[str]` to `dict[str, list[str]]`; legacy list rows preserved
under `"General"` key; list items validated as strings
- `backend/data/tally.py`: LLM prompt updated to generate themed
prompts; validation now per-theme with blank-string rejection
- `backend/api/features/chat/routes.py`: New `SuggestedTheme` model;
endpoint returns `themes[]`
- `frontend/copilot/components/EmptySession/EmptySession.tsx`: Uses
generated API types directly (no cast)
- `frontend/copilot/components/EmptySession/helpers.ts`:
`DEFAULT_THEMES` replaces `DEFAULT_QUICK_ACTIONS`; `getSuggestionThemes`
normalizes partial API responses and distributes legacy `"General"`
prompts across themes
- `frontend/copilot/components/EmptySession/components/SuggestionThemes/`: New popover component with theme icons and loading states

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify themed suggestion buttons render on CoPilot empty session
  - [x] Click each theme button and confirm popover opens with prompts
  - [x] Click a prompt and confirm it sends the message
  - [x] Verify fallback to default themes when API returns no custom prompts
  - [x] Verify legacy users' personalized prompts are preserved and visible


🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 18:46:12 +08:00
Otto
1240f38f75 feat(backend): migrate OpenAI provider to Responses API (#12099)
## Summary

Migrates the OpenAI provider in the LLM block from
`chat.completions.create` to `responses.create` — OpenAI's newer,
unified API. Also removes the obsolete GPT-3.5-turbo model.

Resolves #11624
Linear:
[OPEN-2911](https://linear.app/autogpt/issue/OPEN-2911/update-openai-calls-to-use-responsescreate)

## Changes

- **`backend/blocks/llm.py`** — OpenAI provider now uses
`responses.create` exclusively. Removed GPT-3.5-turbo enum + metadata.
- **`backend/util/openai_responses.py`** *(new)* — Helpers for the
Responses API: tool format conversion, content/reasoning/usage/tool-call
extraction.
- **`backend/util/openai_responses_test.py`** *(new)* — Unit tests for
all helper functions.
- **`backend/data/block_cost_config.py`** — Removed GPT-3.5 cost entry.
- **`docs/integrations/block-integrations/llm.md`** — Regenerated block
docs.

## Key API differences handled

| Aspect | Chat Completions | Responses API |
|--------|-----------------|---------------|
| Messages param | `messages` | `input` |
| Max tokens param | `max_completion_tokens` | `max_output_tokens` |
| Usage fields | `prompt_tokens` / `completion_tokens` | `input_tokens`
/ `output_tokens` |
| Tool format | Nested under `function` key | Flat structure |
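
The differences in the table can be sketched as a small translation layer. This is not the code in `backend/util/openai_responses.py`; the function names are hypothetical, and treating `messages` → `input` as a pure rename is a simplification that only holds for plain role/content messages.

```python
from typing import Any

# Keyword arguments that are renamed between the two APIs (from the table above).
PARAM_RENAMES = {"messages": "input", "max_completion_tokens": "max_output_tokens"}


def to_responses_tool(tool: dict[str, Any]) -> dict[str, Any]:
    """Flatten a Chat Completions tool definition into the Responses API shape."""
    fn = tool["function"]
    flat = {"type": "function", "name": fn["name"]}
    for key in ("description", "parameters", "strict"):
        if key in fn:
            flat[key] = fn[key]
    return flat


def to_responses_kwargs(chat_kwargs: dict[str, Any]) -> dict[str, Any]:
    """Translate Chat Completions keyword arguments to Responses API ones."""
    out = {PARAM_RENAMES.get(k, k): v for k, v in chat_kwargs.items() if k != "tools"}
    if "tools" in chat_kwargs:
        out["tools"] = [to_responses_tool(t) for t in chat_kwargs["tools"]]
    return out


def usage_counts(usage: dict[str, int]) -> tuple[int, int]:
    """Read (input, output) token counts from either usage shape."""
    if "input_tokens" in usage:
        return usage["input_tokens"], usage["output_tokens"]
    return usage["prompt_tokens"], usage["completion_tokens"]
```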

## Test plan

- [x] Unit tests for all `openai_responses.py` helpers
- [x] Existing LLM block tests updated for Responses API mocks
- [x] Regular OpenAI models work
- [x] Reasoning OpenAI models work
- [x] Non-OpenAI models work

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 09:19:31 +00:00
Zamil Majdy
f617f50f0b dx(skills): improve pr-address skill — full thread context + PR description backtick fix (#12480)
## Summary

Improves the `pr-address` skill with two fixes:

- **Full comment thread loading**: Adds `--paginate` to the inline
comments fetch and explicit instructions to reconstruct threads using
`in_reply_to_id`, reading root-to-last-reply before acting. Previously,
only the opening comment was visible — missing reviewer replies led to
wrong fixes.
- **Backtick-safe PR descriptions**: Adds instructions to write the PR
body to a temp file via `<<'PREOF'` heredoc before passing to `gh pr
edit/create`. Inlining the body directly causes backticks to be
shell-escaped, breaking markdown rendering.

## Test plan
- [ ] Run `/pr-address` on a PR with multi-reply inline comment threads
— verify the last reply is what gets acted on
- [ ] Update a PR description containing backticks — verify they render
correctly in GitHub
2026-03-19 15:11:14 +07:00
161 changed files with 10736 additions and 5150 deletions

View File

@@ -19,16 +19,60 @@ gh pr view {N}
## Fetch comments (all sources)
### 1. Inline review threads — GraphQL (primary source of actionable items)
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
gh api graphql -f query='
{
  repository(owner: "Significant-Gravitas", name: "AutoGPT") {
    pullRequest(number: {N}) {
      reviewThreads(first: 100) {
        pageInfo { hasNextPage endCursor }
        nodes {
          id
          isResolved
          path
          comments(last: 1) {
            nodes { databaseId body author { login } createdAt }
          }
        }
      }
    }
  }
}'
```
**Bots to watch for:**
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — automated review. Address actionable items.
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
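
As a minimal sketch of the filtering and paging rules above, the following helpers operate on the parsed JSON of one query response. It assumes the output of `gh api graphql` has been loaded into a dict with the usual top-level `data` key; the function names are hypothetical.

```python
from typing import Any, Optional


def _review_threads(response: dict[str, Any]) -> dict[str, Any]:
    """Navigate to the reviewThreads connection of one response page."""
    return response["data"]["repository"]["pullRequest"]["reviewThreads"]


def next_cursor(response: dict[str, Any]) -> Optional[str]:
    """Return the cursor for the next page, or None when this is the last page."""
    page_info = _review_threads(response)["pageInfo"]
    return page_info["endCursor"] if page_info["hasNextPage"] else None


def unresolved_threads(response: dict[str, Any]) -> list[dict[str, Any]]:
    """Return one record per unresolved thread, based on its most recent comment."""
    records = []
    for thread in _review_threads(response)["nodes"]:
        comments = thread["comments"]["nodes"]
        if thread["isResolved"] or not comments:
            continue
        last = comments[-1]  # comments(last: 1) yields at most one node
        records.append(
            {
                "thread_id": thread["id"],
                "path": thread["path"],
                "last_comment_databaseId": last["databaseId"],
                # author can be null for deleted accounts
                "author": (last.get("author") or {}).get("login"),
                "body": last["body"],
            }
        )
    return records
```

A caller would repeat the query with `after: "<cursor>"` while `next_cursor` returns a value and concatenate the records from each page.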
### 2. Top-level reviews — REST (MUST paginate)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
**CRITICAL: always `--paginate`.** Reviews default to 30 per page. PRs can have 80 to 170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30, including `autogpt-reviewer`'s structured review, which is typically posted after several CI runs and sits well beyond the first page.
Two things to extract:
- **Overall state**: look for `CHANGES_REQUESTED` or `APPROVED` reviews.
- **Actionable feedback**: non-empty bodies only. Empty-body reviews are thread-resolution events — they indicate progress but have no feedback to act on.
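
The two extraction rules can be sketched as below. This is a minimal illustration with a hypothetical function name; it assumes the paginated REST output has been parsed into a single list of review objects in chronological order, each with `state` and `body` fields.

```python
from typing import Any, Optional

DECISIVE_STATES = {"CHANGES_REQUESTED", "APPROVED"}


def summarize_reviews(
    reviews: list[dict[str, Any]],
) -> tuple[Optional[str], list[dict[str, Any]]]:
    """Return (most recent decisive state, reviews with a non-empty body)."""
    decisive = [r["state"] for r in reviews if r.get("state") in DECISIVE_STATES]
    # Empty or whitespace-only bodies are thread-resolution events: skip them.
    actionable = [r for r in reviews if (r.get("body") or "").strip()]
    return (decisive[-1] if decisive else None), actionable
```

Taking the last decisive state across all reviewers is a simplification; tracking the latest state per reviewer would be more precise on PRs with several human reviewers.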
**Where each reviewer posts:**
- `autogpt-reviewer` — posts detailed structured reviews ("Blockers", "Should Fix", "Nice to Have") as **top-level reviews**. Not present on every PR. Address ALL items.
- `sentry[bot]` — posts bug predictions as **inline threads**. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — posts summaries as **top-level reviews** AND actionable items as **inline threads**. Address actionable items.
- Human reviewers — can post in any source. Address ALL non-empty feedback.
### 3. PR conversation comments — REST
```bash
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
## For each unaddressed comment
@@ -94,13 +138,23 @@ gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.merge
```
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
3. Check for new comments (all three sources):
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
```
Compare against previously seen comments to detect new ones.
3. Check for new/changed comments (all three sources):
**Inline threads** — re-run the GraphQL query from "Fetch comments". For each unresolved thread, record `{thread_id, last_comment_databaseId}` as your baseline. On each poll, action is needed if:
- A new thread `id` appears that wasn't in the baseline (new thread), OR
- An existing thread's `last_comment_databaseId` has changed (new reply on existing thread)
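
The baseline comparison above reduces to a dictionary diff. The sketch below assumes each poll has been reduced to a mapping from thread `id` to the `databaseId` of its last comment; the function name is hypothetical.

```python
def threads_needing_action(
    baseline: dict[str, int], current: dict[str, int]
) -> list[str]:
    """Return thread IDs that are new or have a new last comment since the baseline."""
    return [
        thread_id
        for thread_id, last_comment_id in current.items()
        # baseline.get() returns None for unseen threads, which never equals an ID,
        # so new threads and changed threads are caught by the same comparison.
        if baseline.get(thread_id) != last_comment_id
    ]
```

Threads that disappear from `current` (for example because they were resolved and filtered out) are ignored, since they no longer need a response.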
**Conversation comments:**
```bash
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
Compare total count and newest `id` against baseline. Filter to non-empty, non-bot, non-author-update messages.
**Top-level reviews:**
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
Watch for new non-empty reviews (`CHANGES_REQUESTED` or `COMMENTED` with body). Compare total count and newest `id` against baseline.
4. **React in this precedence order (first match wins):**

View File

@@ -28,7 +28,7 @@ gh pr diff {N}
Before posting anything, fetch existing inline comments to avoid duplicates:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
```

View File

@@ -0,0 +1,534 @@
---
name: pr-test
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
author: autogpt-team
version: "1.0.0"
---
# Manual E2E Test
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
## Arguments
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
- If `--fix` flag is present, auto-fix bugs found and push fixes (like pr-address loop)
## Step 0: Resolve the target
```bash
# If argument is a PR number, find its worktree
gh pr view {N} --json headRefName --jq '.headRefName'
# If argument is a path, use it directly
```
Determine:
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
- `WORKTREE_PATH` — the worktree directory
- `PLATFORM_DIR`: `$WORKTREE_PATH/autogpt_platform`
- `BACKEND_DIR`: `$PLATFORM_DIR/backend`
- `FRONTEND_DIR`: `$PLATFORM_DIR/frontend`
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
- `RESULTS_DIR`: `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`
Create the results directory:
```bash
PR_NUMBER=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
PR_TITLE=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
mkdir -p $RESULTS_DIR
```
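
For readers who want to check what the `tr`/`sed`/`head` pipeline produces for a given title, here is a rough Python equivalent. It is an illustrative mirror, not part of the skill, and `slugify_title` is a hypothetical name.

```python
import re


def slugify_title(title: str, max_len: int = 50) -> str:
    """Approximate the shell pipeline: lowercase, dash-separate, trim, truncate."""
    slug = title.lower()
    slug = re.sub(r"[^a-z0-9]", "-", slug)  # anything non-alphanumeric becomes a dash
    slug = re.sub(r"-+", "-", slug)         # collapse runs of dashes
    slug = slug.strip("-")                  # drop a leading or trailing dash
    # Truncation happens last, as in the pipeline, so a long title can still
    # end in a dash after being cut at max_len characters.
    return slug[:max_len]
```

For example, `slugify_title("Add copilot permissions")` yields `add-copilot-permissions`, matching the example given for `PR_TITLE`.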
**Test user credentials** (for logging into the UI or verifying results manually):
- Email: `test@test.com`
- Password: `testtest123`
## Step 1: Understand the PR
Before testing, understand what changed:
```bash
cd $WORKTREE_PATH
git log --oneline dev..HEAD | head -20
git diff dev --stat
```
Read the changed files to understand:
1. What feature/fix does this PR implement?
2. What components are affected? (backend, frontend, copilot, executor, etc.)
3. What are the key user-facing behaviors to test?
## Step 2: Write test scenarios
Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
```markdown
# Test Plan: PR #{N} — {title}
## Scenarios
1. [Scenario name] — [what to verify]
2. ...
## API Tests (if applicable)
1. [Endpoint] — [expected behavior]
## UI Tests (if applicable)
1. [Page/component] — [interaction to test]
## Negative Tests
1. [What should NOT happen]
```
**Be critical** — include edge cases, error paths, and security checks.
## Step 3: Environment setup
### 3a. Copy .env files from the root worktree
The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:
```bash
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
cp $REPO_ROOT/autogpt_platform/.env $PLATFORM_DIR/.env
cp $REPO_ROOT/autogpt_platform/backend/.env $BACKEND_DIR/.env
cp $REPO_ROOT/autogpt_platform/frontend/.env $FRONTEND_DIR/.env
```
### 3b. Configure copilot authentication
The copilot needs an LLM API to function. Two approaches (try subscription first):
#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)
The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.
Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):
```bash
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
bash $BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env
```
**How it works:** The script reads the OAuth token from:
- **macOS**: system keychain (`"Claude Code-credentials"`)
- **Linux/WSL**: `~/.claude/.credentials.json`
- **Windows**: `%APPDATA%/claude/.credentials.json`
It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.
**Note:** The OAuth token expires (~24h). If copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`
#### Option 2: OpenRouter API key mode (fallback)
If subscription mode doesn't work, switch to API key mode using OpenRouter:
```bash
# In $BACKEND_DIR/.env, ensure these are set:
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
CHAT_BASE_URL=https://openrouter.ai/api/v1
CHAT_USE_CLAUDE_AGENT_SDK=true
```
Use `perl` to update these values in place:
```bash
# Use -f2- so a key that itself contains "=" is not truncated
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2-)
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
# Add or update CHAT_API_KEY and CHAT_BASE_URL
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
```
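The add-or-update pattern above (replace the line if the key exists, append it otherwise) can also be sketched in Python, which sidesteps shell quoting and regex delimiter issues. This is a minimal sketch, not part of the repository: `upsert_env_var` and `switch_to_openrouter` are hypothetical helper names, and the variable names are the ones listed in Option 2.

```python
from pathlib import Path


def upsert_env_var(env_text: str, key: str, value: str) -> str:
    """Return env_text with key set to value: replace an existing line or append one."""
    lines = env_text.splitlines()
    replaced = False
    for i, line in enumerate(lines):
        # Match only real assignments of this exact key, not comments or longer keys.
        if line.startswith(f"{key}="):
            lines[i] = f"{key}={value}"
            replaced = True
    if not replaced:
        lines.append(f"{key}={value}")
    return "\n".join(lines) + "\n"


def switch_to_openrouter(env_path: Path) -> None:
    """Hypothetical helper: apply the Option 2 settings to a .env file in place."""
    text = env_path.read_text()
    # Split on the first "=" only, so values containing "=" survive intact.
    key = next(
        (
            line.split("=", 1)[1]
            for line in text.splitlines()
            if line.startswith("OPEN_ROUTER_API_KEY=")
        ),
        "",
    )
    if not key:
        raise ValueError(f"OPEN_ROUTER_API_KEY is missing in {env_path}")
    settings = {
        "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION": "false",
        "CHAT_API_KEY": key,
        "CHAT_BASE_URL": "https://openrouter.ai/api/v1",
        "CHAT_USE_CLAUDE_AGENT_SDK": "true",
    }
    for name, value in settings.items():
        text = upsert_env_var(text, name, value)
    env_path.write_text(text)
```

Unlike the `perl` substitution on `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION`, this sketch also adds the variable when it is absent from the file.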
### 3c. Stop conflicting containers
```bash
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
  docker stop "$name" 2>/dev/null
done
```
### 3e. Build and start
```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi
cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
```
**Note:** The build above uses `--no-cache` on purpose. Docker BuildKit may reuse cached `COPY` layers from a previous build on a different branch, which leaves the container running old code (e.g. missing PR changes). If you skipped `--no-cache` and see that symptom, rerun `docker compose build --no-cache` to force a full rebuild.
**Expected time: 3-8 minutes** for build, 5-10 minutes with `--no-cache`.
### 3f. Wait for services to be ready
```bash
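# Poll until backend and frontend respond (60 attempts x 5s = up to 5 minutes)
READY=false
for i in $(seq 1 60); do
  BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
  FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
  if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
    echo "Services ready"
    READY=true
    break
  fi
  sleep 5
done
# Without this check the loop ends silently when the services never come up
[ "$READY" = true ] || echo "ERROR: services not ready after 5 minutes (backend=$BACKEND frontend=$FRONTEND)"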
```
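The same readiness poll can be sketched in Python when the test driver is a script rather than a shell session. This is an illustrative sketch only: `http_status` and `wait_until_ready` are hypothetical names, and the `probe` and `sleep` parameters exist so the loop can be exercised without real services.

```python
import http.client
import time
import urllib.error
import urllib.request
from typing import Callable, Sequence


def http_status(url: str, timeout: float = 5.0) -> int:
    """Return the HTTP status code for url, or 0 if no response was received."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status
    except urllib.error.HTTPError as exc:
        # The server answered, just not with a 2xx/3xx status.
        return exc.code
    except (urllib.error.URLError, http.client.HTTPException, OSError):
        # Connection refused, DNS failure, timeout, or a malformed response.
        return 0


def wait_until_ready(
    urls: Sequence[str],
    attempts: int = 60,
    delay: float = 5.0,
    probe: Callable[[str], int] = http_status,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll every URL until all return 200; give up after `attempts` rounds."""
    for _ in range(attempts):
        if all(probe(url) == 200 for url in urls):
            return True
        sleep(delay)
    return False
```

Calling `wait_until_ready(["http://localhost:8006/docs", "http://localhost:3000"])` mirrors the shell loop, and the boolean return value makes the timeout case explicit to the caller.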
### 3h. Create test user and get auth token
```bash
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')
# Signup (idempotent — returns "User already registered" if exists)
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
-H "apikey: $ANON_KEY" \
-H 'Content-Type: application/json' \
-d '{"email":"test@test.com","password":"testtest123"}')
# If "Database error finding user", restart supabase-auth and retry
if echo "$RESULT" | grep -q "Database error"; then
  docker restart supabase-auth && sleep 5
  curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
    -H "apikey: $ANON_KEY" \
    -H 'Content-Type: application/json' \
    -d '{"email":"test@test.com","password":"testtest123"}'
fi
# Get auth token
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
-H "apikey: $ANON_KEY" \
-H 'Content-Type: application/json' \
-d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
```
**Use this token for ALL API calls:**
```bash
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
```
## Step 4: Run tests
### Service ports reference
| Service | Port | URL |
|---------|------|-----|
| Frontend | 3000 | http://localhost:3000 |
| Backend REST | 8006 | http://localhost:8006 |
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
| Executor | 8002 | http://localhost:8002 |
| Copilot Executor | 8008 | http://localhost:8008 |
| WebSocket | 8001 | http://localhost:8001 |
| Database Manager | 8005 | http://localhost:8005 |
| Redis | 6379 | localhost:6379 |
| RabbitMQ | 5672 | localhost:5672 |
### API testing
Use `curl` with the auth token for backend API tests:
```bash
# Example: List agents
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20
# Example: Create an agent
curl -s -X POST http://localhost:8006/api/graphs \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{...}' | jq .
# Example: Run an agent
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{"data": {...}}'
# Example: Get execution results
curl -s -H "Authorization: Bearer $TOKEN" \
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
```
### Browser testing with agent-browser
```bash
# Close any existing session
agent-browser close 2>/dev/null || true
# Use --session-name to persist cookies across navigations
# This means login only needs to happen once per test session
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000
# Get interactive elements
agent-browser --session-name pr-test snapshot | grep "textbox\|button"
# Login
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
agent-browser --session-name pr-test fill {password_ref} "testtest123"
agent-browser --session-name pr-test click {login_button_ref}
sleep 5
# Dismiss cookie banner if present
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true
# Navigate — cookies are preserved so login persists
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
# Take screenshot
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png
# Interact with elements
agent-browser --session-name pr-test fill {ref} "text"
agent-browser --session-name pr-test press "Enter"
agent-browser --session-name pr-test click {ref}
agent-browser --session-name pr-test click 'text=Button Text'
# Read page content
agent-browser --session-name pr-test snapshot | grep "text:"
```
**Key pages:**
- `/copilot` — CoPilot chat (for testing copilot features)
- `/build` — Agent builder (for testing block/node features)
- `/build?flowID={id}` — Specific agent in builder
- `/library` — Agent library (for testing listing/import features)
- `/library/agents/{id}` — Agent detail with run history
- `/marketplace` — Marketplace
### Checking logs
```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
# Executor (runs agent graphs)
docker logs autogpt_platform-executor-1 2>&1 | tail -30
# Copilot executor (runs copilot chat sessions)
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30
# Frontend
docker logs autogpt_platform-frontend-1 2>&1 | tail -30
# Filter for errors
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
```
### Copilot chat testing
The copilot uses SSE streaming. To test via API:
```bash
# Create a session
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{}' | jq -r '.id // .session_id // ""')
# Stream a message (SSE - will stream chunks)
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
-H "Authorization: Bearer $TOKEN" \
-H 'Content-Type: application/json' \
-d '{"message": "Hello, what can you help me with?"}' \
--max-time 60 2>/dev/null | head -50
```
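When the streamed output needs to be asserted on rather than eyeballed, the SSE framing can be decoded with a few lines of Python. The sketch below follows the generic Server-Sent Events wire format (`data:` fields, `:` comment lines, blank line as event terminator). It makes no assumption about what the copilot puts inside each `data:` payload, and `iter_sse_data` is a hypothetical helper name.

```python
from typing import Iterable, Iterator


def iter_sse_data(lines: Iterable[str]) -> Iterator[str]:
    """Yield the data payload of each complete SSE event found in `lines`."""
    buffer: list[str] = []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":
            # A blank line terminates the current event.
            if buffer:
                yield "\n".join(buffer)
                buffer = []
            continue
        if line.startswith(":"):
            # Comment line, commonly used as a keep-alive ping.
            continue
        field, _, value = line.partition(":")
        if field == "data":
            # The format allows one optional space after the colon.
            buffer.append(value[1:] if value.startswith(" ") else value)
        # Other fields (event, id, retry) are ignored in this sketch.
    # An unterminated trailing event is dropped, matching the SSE processing model.
```

The function accepts any iterable of lines, so it works equally on a saved `curl -N` capture and on a streaming HTTP response iterated line by line.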
Or test via browser (preferred for UI verification):
```bash
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
# ... fill chat input and press Enter, wait 20-30s for response
```
## Step 5: Record results
For each test scenario, record in `$RESULTS_DIR/test-report.md`:
```markdown
# E2E Test Report: PR #{N} — {title}
Date: {date}
Branch: {branch}
Worktree: {path}
## Environment
- Docker services: [list running containers]
- API keys: OpenRouter={present/missing}, E2B={present/missing}
## Test Results
### Scenario 1: {name}
**Steps:**
1. ...
2. ...
**Expected:** ...
**Actual:** ...
**Result:** PASS / FAIL
**Screenshot:** {filename}.png
**Logs:** (if relevant)
### Scenario 2: {name}
...
## Summary
- Total: X scenarios
- Passed: Y
- Failed: Z
- Bugs found: [list]
```
Take screenshots at each significant step:
```bash
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{description}.png
```
## Step 6: Report results
After all tests complete, output a summary to the user:
1. Table of all scenarios with PASS/FAIL
2. Screenshots of failures (read the PNG files to show them)
3. Any bugs found with details
4. Recommendations
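One way to produce the PASS/FAIL table is to keep results as structured records and render the Markdown at the end. This is a minimal sketch under that assumption: `ScenarioResult` and `render_summary` are hypothetical names, not part of the skill.

```python
from dataclasses import dataclass


@dataclass
class ScenarioResult:
    name: str
    passed: bool
    notes: str = ""


def render_summary(results: list[ScenarioResult]) -> str:
    """Render a Markdown table of scenarios followed by a one-line tally."""
    lines = [
        "| # | Scenario | Result | Notes |",
        "|---|----------|--------|-------|",
    ]
    for index, result in enumerate(results, start=1):
        status = "PASS" if result.passed else "FAIL"
        lines.append(f"| {index} | {result.name} | {status} | {result.notes} |")
    passed = sum(1 for result in results if result.passed)
    failed = len(results) - passed
    lines.append("")
    lines.append(f"Total: {len(results)}, Passed: {passed}, Failed: {failed}")
    return "\n".join(lines)
```

Scenario names and notes are inserted verbatim, so a literal `|` in either would break the table and should be escaped by the caller.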
### Post test results as PR comment with screenshots
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees).
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
# Step 1: Create one blob per screenshot and collect the tree entries as JSON
# (the tree JSON is built manually since gh api's -f flags don't handle arrays well;
# each screenshot is uploaded exactly once)
TREE_JSON='['
FIRST=true
for img in $RESULTS_DIR/*.png; do
  BASENAME=$(basename "$img")
  B64=$(base64 < "$img")
  BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha')
  if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
  TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
done
TREE_JSON+=']'
# Step 2: Create a tree with all screenshot blobs (request body is {"tree": [...]})
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
# Step 3: Create a commit pointing to that tree
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
# Step 4: Create or update the ref (branch) — no local checkout needed
gh api "repos/${REPO}/git/refs" \
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
-f sha="$COMMIT_SHA" 2>/dev/null \
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
-X PATCH -f sha="$COMMIT_SHA" -f force=true
# Step 5: Build image markdown and post the comment
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
IMAGE_MARKDOWN=""
for img in $RESULTS_DIR/*.png; do
BASENAME=$(basename "$img")
IMAGE_MARKDOWN="$IMAGE_MARKDOWN
![${BASENAME}](${REPO_URL}/${SCREENSHOTS_DIR}/${BASENAME})"
done
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -f body="$(cat <<EOF
## 🧪 E2E Test Report
$(cat $RESULTS_DIR/test-report.md)
### Screenshots
${IMAGE_MARKDOWN}
EOF
)"
```
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
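The tree request body is the part of this flow that is easiest to get wrong with string concatenation, because file names are spliced into JSON unescaped. As a sketch of an alternative, the payload can be built with a JSON encoder; `build_tree_payload` is a hypothetical helper, and the blob SHAs are assumed to come from the blob-creation step above.

```python
import json
from pathlib import PurePosixPath


def build_tree_payload(screenshots_dir: str, blobs: dict[str, str]) -> str:
    """Build the JSON body for a Git trees request from {file name: blob SHA}."""
    entries = [
        {
            # Tree paths always use forward slashes, regardless of the host OS.
            "path": str(PurePosixPath(screenshots_dir) / name),
            "mode": "100644",  # regular, non-executable file
            "type": "blob",
            "sha": sha,
        }
        for name, sha in sorted(blobs.items())
    ]
    return json.dumps({"tree": entries})
```

The returned string can be piped to `gh api "repos/${REPO}/git/trees" --input -` in place of the hand-assembled `TREE_JSON`, and the encoder takes care of quotes or backslashes in file names.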
## Fix mode (--fix flag)
When `--fix` is present, after finding a bug:
1. Identify the root cause in the code
2. Fix it in the worktree
3. Rebuild the affected service: `cd $PLATFORM_DIR && docker compose up --build -d {service_name}`
4. Re-test the scenario
5. If fix works, commit and push:
```bash
cd $WORKTREE_PATH
git add -A
git commit -m "fix: {description of fix}"
git push
```
6. Continue testing remaining scenarios
7. After all fixes, run the full test suite again to ensure no regressions
### Fix loop (like pr-address)
```text
test scenario → find bug → fix code → rebuild service → re-test
→ repeat until all scenarios pass
→ commit + push all fixes
→ run full re-test to verify
```
## Known issues and workarounds
### Problem: "Database error finding user" on signup
**Cause:** Supabase auth service schema cache is stale after migration.
**Fix:** `docker restart supabase-auth && sleep 5` then retry signup.
### Problem: Copilot returns auth errors in subscription mode
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is unset or has expired.
**Fix:** Re-run the token helper script (see step 3b, Option 1) to refresh the OAuth token in `.env`, then recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` is needed because the SDK bundles its own CLI binary.
### Problem: agent-browser can't find chromium
**Cause:** The current Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, its image may not include chromium yet.
**Fix:** Check if chromium exists: `which chromium || which chromium-browser`. If missing, install it: `apt-get install -y chromium` and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.
### Problem: agent-browser selector matches multiple elements
**Cause:** `text=X` matches all elements containing that text.
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.
### Problem: Frontend shows cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
### Problem: Docker uses cached layers with old code (PR changes not visible)
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but the previous build already cached that layer from `dev`, the container runs `dev` code.
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.
### Problem: `agent-browser open` loses login session
**Cause:** Without session persistence, `agent-browser open` starts fresh.
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.
### Problem: Supabase auth returns "Database error querying schema"
**Cause:** The database schema changed (migration ran) but supabase-auth has a stale schema cache.
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.

@@ -56,15 +56,35 @@ AutoGPT Platform is a monorepo containing:
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
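When post-processing paginated output in a script, it is safer not to assume the pages arrive as one JSON array. Depending on the `gh` version and flags, the pages may be written back to back as separate JSON documents. The sketch below tolerates both shapes; `parse_concatenated_json` is a hypothetical helper, not part of `gh`.

```python
import json
from typing import Any


def parse_concatenated_json(text: str) -> list[Any]:
    """Decode one or more back-to-back JSON documents and flatten top-level arrays."""
    decoder = json.JSONDecoder()
    items: list[Any] = []
    pos = 0
    while pos < len(text):
        # raw_decode rejects leading whitespace, so skip it between documents.
        if text[pos].isspace():
            pos += 1
            continue
        document, pos = decoder.raw_decode(text, pos)
        if isinstance(document, list):
            items.extend(document)
        else:
            items.append(document)
    return items
```

Feeding it the captured stdout of a paginated call yields a single flat list of comment objects whether the input was `[...]` or `[...][...]`.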
### Conventional Commits

@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
## ===== SIGNUP / INVITE GATE ===== ##
# Set to true to require an invite before users can sign up
ENABLE_INVITE_GATE=false
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000

@@ -85,6 +85,30 @@ poetry run pytest path/to/test.py --snapshot-update
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):

@@ -1,17 +1,8 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.util.models import Pagination
if TYPE_CHECKING:
    from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
class UserHistoryResponse(BaseModel):
    """Response model for listings with version history"""
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
    new_balance: int
    transaction_key: str
class CreateInvitedUserRequest(BaseModel):
    email: EmailStr
    name: Optional[str] = None
class InvitedUserResponse(BaseModel):
    id: str
    email: str
    status: prisma.enums.InvitedUserStatus
    auth_user_id: Optional[str] = None
    name: Optional[str] = None
    tally_understanding: Optional[dict[str, Any]] = None
    tally_status: prisma.enums.TallyComputationStatus
    tally_computed_at: Optional[datetime] = None
    tally_error: Optional[str] = None
    created_at: datetime
    updated_at: datetime
    @classmethod
    def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
        return cls.model_validate(record.model_dump())
class InvitedUsersResponse(BaseModel):
    invited_users: list[InvitedUserResponse]
    pagination: Pagination
class BulkInvitedUserRowResponse(BaseModel):
    row_number: int
    email: Optional[str] = None
    name: Optional[str] = None
    status: Literal["CREATED", "SKIPPED", "ERROR"]
    message: str
    invited_user: Optional[InvitedUserResponse] = None
class BulkInvitedUsersResponse(BaseModel):
    created_count: int
    skipped_count: int
    error_count: int
    results: list[BulkInvitedUserRowResponse]
    @classmethod
    def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
        return cls(
            created_count=result.created_count,
            skipped_count=result.skipped_count,
            error_count=result.error_count,
            results=[
                BulkInvitedUserRowResponse(
                    row_number=row.row_number,
                    email=row.email,
                    name=row.name,
                    status=row.status,
                    message=row.message,
                    invited_user=(
                        InvitedUserResponse.from_record(row.invited_user)
                        if row.invited_user is not None
                        else None
                    ),
                )
                for row in result.results
            ],
        )

@@ -1,137 +0,0 @@
import logging
import math
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Query, Security, UploadFile
from backend.data.invited_user import (
    bulk_create_invited_users_from_file,
    create_invited_user,
    list_invited_users,
    retry_invited_user_tally,
    revoke_invited_user,
)
from backend.data.tally import mask_email
from backend.util.models import Pagination
from .model import (
    BulkInvitedUsersResponse,
    CreateInvitedUserRequest,
    InvitedUserResponse,
    InvitedUsersResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(
    prefix="/admin",
    tags=["users", "admin"],
    dependencies=[Security(requires_admin_user)],
)
@router.get(
    "/invited-users",
    response_model=InvitedUsersResponse,
    summary="List Invited Users",
)
async def get_invited_users(
    admin_user_id: str = Security(get_user_id),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
) -> InvitedUsersResponse:
    logger.info("Admin user %s requested invited users", admin_user_id)
    invited_users, total = await list_invited_users(page=page, page_size=page_size)
    return InvitedUsersResponse(
        invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
        pagination=Pagination(
            total_items=total,
            total_pages=max(1, math.ceil(total / page_size)),
            current_page=page,
            page_size=page_size,
        ),
    )
@router.post(
    "/invited-users",
    response_model=InvitedUserResponse,
    summary="Create Invited User",
)
async def create_invited_user_route(
    request: CreateInvitedUserRequest,
    admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
    logger.info(
        "Admin user %s creating invited user for %s",
        admin_user_id,
        mask_email(request.email),
    )
    invited_user = await create_invited_user(request.email, request.name)
    logger.info(
        "Admin user %s created invited user %s",
        admin_user_id,
        invited_user.id,
    )
    return InvitedUserResponse.from_record(invited_user)
@router.post(
    "/invited-users/bulk",
    response_model=BulkInvitedUsersResponse,
    summary="Bulk Create Invited Users",
    operation_id="postV2BulkCreateInvitedUsers",
)
async def bulk_create_invited_users_route(
    file: UploadFile = File(...),
    admin_user_id: str = Security(get_user_id),
) -> BulkInvitedUsersResponse:
    logger.info(
        "Admin user %s bulk invited users from %s",
        admin_user_id,
        file.filename or "<unnamed>",
    )
    content = await file.read()
    result = await bulk_create_invited_users_from_file(file.filename, content)
    return BulkInvitedUsersResponse.from_result(result)
@router.post(
    "/invited-users/{invited_user_id}/revoke",
    response_model=InvitedUserResponse,
    summary="Revoke Invited User",
)
async def revoke_invited_user_route(
    invited_user_id: str,
    admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
    logger.info(
        "Admin user %s revoking invited user %s", admin_user_id, invited_user_id
    )
    invited_user = await revoke_invited_user(invited_user_id)
    logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
    return InvitedUserResponse.from_record(invited_user)
@router.post(
    "/invited-users/{invited_user_id}/retry-tally",
    response_model=InvitedUserResponse,
    summary="Retry Invited User Tally",
)
async def retry_invited_user_tally_route(
    invited_user_id: str,
    admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
    logger.info(
        "Admin user %s retrying Tally seed for invited user %s",
        admin_user_id,
        invited_user_id,
    )
    invited_user = await retry_invited_user_tally(invited_user_id)
    logger.info(
        "Admin user %s retried Tally seed for invited user %s",
        admin_user_id,
        invited_user_id,
    )
    return InvitedUserResponse.from_record(invited_user)

@@ -1,168 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.invited_user import (
    BulkInvitedUserRowResult,
    BulkInvitedUsersResult,
    InvitedUserRecord,
)
from .user_admin_routes import router as user_admin_router
app = fastapi.FastAPI()
app.include_router(user_admin_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()
def _sample_invited_user() -> InvitedUserRecord:
    now = datetime.now(timezone.utc)
    return InvitedUserRecord(
        id="invite-1",
        email="invited@example.com",
        status=prisma.enums.InvitedUserStatus.INVITED,
        auth_user_id=None,
        name="Invited User",
        tally_understanding=None,
        tally_status=prisma.enums.TallyComputationStatus.PENDING,
        tally_computed_at=None,
        tally_error=None,
        created_at=now,
        updated_at=now,
    )
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
    return BulkInvitedUsersResult(
        created_count=1,
        skipped_count=1,
        error_count=0,
        results=[
            BulkInvitedUserRowResult(
                row_number=1,
                email="invited@example.com",
                name=None,
                status="CREATED",
                message="Invite created",
                invited_user=_sample_invited_user(),
            ),
            BulkInvitedUserRowResult(
                row_number=2,
                email="duplicate@example.com",
                name=None,
                status="SKIPPED",
                message="An invited user with this email already exists",
                invited_user=None,
            ),
        ],
    )
def test_get_invited_users(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.user_admin_routes.list_invited_users",
        AsyncMock(return_value=([_sample_invited_user()], 1)),
    )
    response = client.get("/admin/invited-users")
    assert response.status_code == 200
    data = response.json()
    assert len(data["invited_users"]) == 1
    assert data["invited_users"][0]["email"] == "invited@example.com"
    assert data["invited_users"][0]["status"] == "INVITED"
    assert data["pagination"]["total_items"] == 1
    assert data["pagination"]["current_page"] == 1
    assert data["pagination"]["page_size"] == 50
def test_create_invited_user(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.user_admin_routes.create_invited_user",
        AsyncMock(return_value=_sample_invited_user()),
    )
    response = client.post(
        "/admin/invited-users",
        json={"email": "invited@example.com", "name": "Invited User"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["email"] == "invited@example.com"
    assert data["name"] == "Invited User"
def test_bulk_create_invited_users(
    mocker: pytest_mock.MockerFixture,
) -> None:
    mocker.patch(
        "backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
        AsyncMock(return_value=_sample_bulk_invited_users_result()),
    )
    response = client.post(
        "/admin/invited-users/bulk",
        files={
            "file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert data["created_count"] == 1
    assert data["skipped_count"] == 1
    assert data["results"][0]["status"] == "CREATED"
    assert data["results"][1]["status"] == "SKIPPED"
def test_revoke_invited_user(
    mocker: pytest_mock.MockerFixture,
) -> None:
    revoked = _sample_invited_user().model_copy(
        update={"status": prisma.enums.InvitedUserStatus.REVOKED}
    )
    mocker.patch(
        "backend.api.features.admin.user_admin_routes.revoke_invited_user",
        AsyncMock(return_value=revoked),
    )
    response = client.post("/admin/invited-users/invite-1/revoke")
    assert response.status_code == 200
    assert response.json()["status"] == "REVOKED"
def test_retry_invited_user_tally(
    mocker: pytest_mock.MockerFixture,
) -> None:
    retried = _sample_invited_user().model_copy(
        update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
    )
    mocker.patch(
        "backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
        AsyncMock(return_value=retried),
    )
    response = client.post("/admin/invited-users/invite-1/retry-tally")
    assert response.status_code == 200
    assert response.json()["tally_status"] == "RUNNING"

@@ -60,7 +60,6 @@ from backend.copilot.tools.models import (
)
from backend.copilot.tracking import track_user_message
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
@@ -895,36 +894,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedPromptsResponse(BaseModel):
    """Response model for user-specific suggested prompts."""
    prompts: list[str]
@router.get(
    "/suggested-prompts",
    dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
    user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
    """
    Get LLM-generated suggested prompts for the authenticated user.
    Returns personalized quick-action prompts based on the user's
    business understanding. Returns an empty list if no custom prompts
    are available.
    """
    understanding = await get_business_understanding(user_id)
    if understanding is None:
        return SuggestedPromptsResponse(prompts=[])
    return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
# ========== Configuration ==========

@@ -1,7 +1,7 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -400,62 +400,3 @@ def test_usage_rejects_unauthenticated_request() -> None:
response = unauthenticated_client.get("/usage")
assert response.status_code == 401
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
    mocker: pytest_mock.MockerFixture,
    *,
    return_value=None,
):
    """Mock get_business_understanding."""
    return mocker.patch(
        "backend.api.features.chat.routes.get_business_understanding",
        new_callable=AsyncMock,
        return_value=return_value,
    )
def test_suggested_prompts_returns_prompts(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """User with understanding and prompts gets them back."""
    mock_understanding = MagicMock()
    mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
    _mock_get_business_understanding(mocker, return_value=mock_understanding)
    response = client.get("/suggested-prompts")
    assert response.status_code == 200
    assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
def test_suggested_prompts_no_understanding(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """User with no understanding gets empty list."""
    _mock_get_business_understanding(mocker, return_value=None)
    response = client.get("/suggested-prompts")
    assert response.status_code == 200
    assert response.json() == {"prompts": []}
def test_suggested_prompts_empty_prompts(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """User with understanding but no prompts gets empty list."""
    mock_understanding = MagicMock()
    mock_understanding.suggested_prompts = []
    _mock_get_business_understanding(mocker, return_value=mock_understanding)
    response = client.get("/suggested-prompts")
    assert response.status_code == 200
    assert response.json() == {"prompts": []}


@@ -24,7 +24,7 @@ from fastapi import (
UploadFile,
)
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
@@ -55,11 +55,6 @@ from backend.data.credit import (
set_auto_top_up,
)
from backend.data.graph import GraphSettings
from backend.data.invited_user import (
check_invite_eligibility,
get_or_activate_user,
is_internal_email,
)
from backend.data.model import CredentialsMetaInput, UserOnboarding
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
@@ -74,8 +69,8 @@ from backend.data.onboarding import (
reset_user_onboarding,
update_user_onboarding,
)
from backend.data.redis_client import get_redis_async
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
@@ -134,69 +129,6 @@ v1_router = APIRouter()
_tally_background_tasks: set[asyncio.Task] = set()
class CheckInviteRequest(BaseModel):
email: EmailStr
class CheckInviteResponse(BaseModel):
allowed: bool
_CHECK_INVITE_RATE_LIMIT = 10 # requests
_CHECK_INVITE_RATE_WINDOW = 60 # seconds
@v1_router.post(
"/auth/check-invite",
summary="Check if an email is allowed to sign up",
tags=["auth"],
)
async def check_invite_route(
http_request: Request,
request: CheckInviteRequest,
) -> CheckInviteResponse:
"""Check if an email is allowed to sign up (no auth required).
Called by the frontend before creating a Supabase auth user to prevent
orphaned accounts when the invite gate is enabled.
"""
client_ip = (
http_request.headers.get("x-forwarded-for", "").split(",")[0].strip()
or http_request.headers.get("x-real-ip", "")
or (http_request.client.host if http_request.client else "unknown")
)
rate_key = f"rate:check-invite:{client_ip}"
try:
redis = await get_redis_async()
# Use a pipeline so that incr + expire are sent atomically.
# This prevents the key from persisting indefinitely when expire fails
# after a successful incr (which would permanently block the IP once
# the count exceeds the limit).
# NOTE: pipeline command methods (incr, expire) are NOT awaitable —
# they queue the command and return the pipeline. Only execute() is
# awaited, which flushes all queued commands in a single round-trip.
pipe = redis.pipeline()
pipe.incr(rate_key)
pipe.expire(rate_key, _CHECK_INVITE_RATE_WINDOW)
results = await pipe.execute()
count = results[0]
if count > _CHECK_INVITE_RATE_LIMIT:
raise HTTPException(status_code=429, detail="Too many requests")
except HTTPException:
raise
except Exception:
logger.debug("Rate limit check failed for check-invite, failing open")
if not settings.config.enable_invite_gate:
return CheckInviteResponse(allowed=True)
if is_internal_email(request.email):
return CheckInviteResponse(allowed=True)
allowed = await check_invite_eligibility(request.email)
return CheckInviteResponse(allowed=allowed)
@v1_router.post(
"/auth/user",
summary="Get or create user",
@@ -204,10 +136,12 @@ async def check_invite_route(
dependencies=[Security(requires_user)],
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
# not produce a stored result before first activation.
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
@@ -231,8 +165,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
dependencies=[Security(requires_user)],
)
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
) -> dict[str, str]:
await update_user_email(user_id, email)
@@ -246,16 +179,10 @@ async def update_user_email_route(
dependencies=[Security(requires_user)],
)
async def get_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
try:
user = await get_user_by_id(user_id)
except ValueError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail="User not found. Please complete activation via /auth/user first.",
)
user = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -266,8 +193,7 @@ async def get_user_timezone_route(
dependencies=[Security(requires_user)],
)
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))
@@ -666,6 +592,11 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
async def configure_user_auto_top_up(
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
) -> str:
"""Configure auto top-up settings and perform an immediate top-up if needed.
Raises HTTPException(422) if the request parameters are invalid or if
the credit top-up fails.
"""
if request.threshold < 0:
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
if request.amount < 500 and request.amount != 0:
@@ -680,10 +611,20 @@ async def configure_user_auto_top_up(
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance < request.threshold:
await user_credit_model.top_up_credits(user_id, request.amount)
else:
await user_credit_model.top_up_credits(user_id, 0)
try:
if current_balance < request.threshold:
await user_credit_model.top_up_credits(user_id, request.amount)
else:
await user_credit_model.top_up_credits(user_id, 0)
except ValueError as e:
known_messages = (
"must not be negative",
"already exists for user",
"No payment method found",
)
if any(msg in str(e) for msg in known_messages):
raise HTTPException(status_code=422, detail=str(e))
raise
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
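
The error handling added to configure_user_auto_top_up above can be read as a small translation step: a ValueError whose text contains a recognised fragment becomes a 422, and anything else propagates unchanged. A minimal sketch of that idea follows, assuming a stand-in HTTPError class instead of FastAPI's HTTPException; the helper name translate_top_up_error is hypothetical and does not appear in the change.

```python
class HTTPError(Exception):
    """Hypothetical stand-in for fastapi.HTTPException (assumption)."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


# Message fragments copied from the diff above.
KNOWN_MESSAGES = (
    "must not be negative",
    "already exists for user",
    "No payment method found",
)


def translate_top_up_error(exc: ValueError) -> Exception:
    """Return a 422 error for recognised messages, else the original exception."""
    text = str(exc)
    if any(fragment in text for fragment in KNOWN_MESSAGES):
        return HTTPError(422, text)
    return exc


if __name__ == "__main__":
    err = translate_top_up_error(ValueError("No payment method found for user"))
    print(type(err).__name__, getattr(err, "status_code", None))
```

Matching on substrings of the message is brittle if the wording changes upstream; a dedicated exception type would be sturdier, though the diff does not show one.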


@@ -35,102 +35,6 @@ def setup_app_auth(mock_jwt_user, setup_test_user):
app.dependency_overrides.clear()
# check_invite_route tests
_RATE_LIMIT_PATCH = "backend.api.features.v1.get_redis_async"
def _make_redis_mock(count: int = 1) -> AsyncMock:
"""Return a mock Redis client that reports `count` for the rate-limit key.
The route uses a pipeline where incr/expire are synchronous (they queue
commands and return the pipeline) and only execute() is awaited.
"""
mock_pipe = Mock()
mock_pipe.incr = Mock(return_value=mock_pipe)
mock_pipe.expire = Mock(return_value=mock_pipe)
mock_pipe.execute = AsyncMock(return_value=[count, True])
mock_redis = AsyncMock()
mock_redis.pipeline = Mock(return_value=mock_pipe)
return mock_redis
def test_check_invite_gate_disabled(mocker: pytest_mock.MockFixture) -> None:
"""When enable_invite_gate is False every email is allowed."""
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
mocker.patch(
"backend.api.features.v1.settings",
Mock(config=Mock(enable_invite_gate=False)),
)
response = client.post("/auth/check-invite", json={"email": "anyone@example.com"})
assert response.status_code == 200
assert response.json() == {"allowed": True}
def test_check_invite_internal_email_bypasses_gate(
mocker: pytest_mock.MockFixture,
) -> None:
"""@agpt.co addresses bypass the gate even when it is enabled."""
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
mocker.patch(
"backend.api.features.v1.settings",
Mock(config=Mock(enable_invite_gate=True)),
)
response = client.post("/auth/check-invite", json={"email": "employee@agpt.co"})
assert response.status_code == 200
assert response.json() == {"allowed": True}
def test_check_invite_eligible_email(mocker: pytest_mock.MockFixture) -> None:
"""An email with INVITED status is allowed when the gate is enabled."""
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
mocker.patch(
"backend.api.features.v1.settings",
Mock(config=Mock(enable_invite_gate=True)),
)
mocker.patch(
"backend.api.features.v1.check_invite_eligibility",
new=AsyncMock(return_value=True),
)
response = client.post("/auth/check-invite", json={"email": "invited@example.com"})
assert response.status_code == 200
assert response.json() == {"allowed": True}
def test_check_invite_ineligible_email(mocker: pytest_mock.MockFixture) -> None:
"""An email without an active invite is denied when the gate is enabled."""
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock())
mocker.patch(
"backend.api.features.v1.settings",
Mock(config=Mock(enable_invite_gate=True)),
)
mocker.patch(
"backend.api.features.v1.check_invite_eligibility",
new=AsyncMock(return_value=False),
)
response = client.post("/auth/check-invite", json={"email": "stranger@example.com"})
assert response.status_code == 200
assert response.json() == {"allowed": False}
def test_check_invite_rate_limit_exceeded(mocker: pytest_mock.MockFixture) -> None:
"""Requests beyond the per-IP rate limit receive HTTP 429."""
mocker.patch(_RATE_LIMIT_PATCH, return_value=_make_redis_mock(count=11))
response = client.post("/auth/check-invite", json={"email": "flood@example.com"})
assert response.status_code == 429
# Auth endpoints tests
def test_get_or_create_user_route(
mocker: pytest_mock.MockFixture,
@@ -147,7 +51,7 @@ def test_get_or_create_user_route(
}
mocker.patch(
"backend.api.features.v1.get_or_activate_user",
"backend.api.features.v1.get_or_create_user",
return_value=mock_user,
)


@@ -188,6 +188,7 @@ async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file: UploadFile,
session_id: str | None = Query(default=None),
overwrite: bool = Query(default=False),
) -> UploadFileResponse:
"""
Upload a file to the user's workspace.
@@ -248,7 +249,9 @@ async def upload_file(
# Write file via WorkspaceManager
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(content, filename)
workspace_file = await manager.write_file(
content, filename, overwrite=overwrite
)
except ValueError as e:
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e


@@ -1,3 +1,4 @@
import asyncio
import contextlib
import logging
import platform
@@ -19,7 +20,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.user_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -38,8 +38,10 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
@@ -118,16 +120,56 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Load LLM registry before initializing blocks so blocks can use registry data.
# Tries Redis first (fast path on warm restart), falls back to DB.
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry loaded successfully at startup")
except Exception as e:
logger.warning(
f"Failed to load LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
# Start background task so this worker reloads its in-process cache whenever
# another worker (e.g. the admin API) refreshes the registry.
_registry_subscription_task = asyncio.create_task(
backend.data.llm_registry.subscribe_to_registry_refresh(
backend.data.llm_registry.refresh_llm_registry
)
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
err_str = str(e)
if "AgentNode" in err_str or "does not exist" in err_str:
logger.warning(
f"migrate_llm_models skipped: AgentNode table not found ({e}). "
"This is expected in test environments."
)
else:
logger.error(
f"migrate_llm_models failed unexpectedly: {e}",
exc_info=True,
)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
yield
_registry_subscription_task.cancel()
try:
await _registry_subscription_task
except asyncio.CancelledError:
pass
try:
await shutdown_cloud_storage_handler()
except Exception as e:
@@ -211,13 +253,22 @@ instrument_fastapi(
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: fastapi.Request, exc: Exception):
if log_error:
logger.exception(
"%s %s failed. Investigate and resolve the underlying issue: %s",
request.method,
request.url.path,
exc,
exc_info=exc,
)
if status_code >= 500:
logger.exception(
"%s %s failed. Investigate and resolve the underlying issue: %s",
request.method,
request.url.path,
exc,
exc_info=exc,
)
else:
logger.warning(
"%s %s failed with %d: %s",
request.method,
request.url.path,
status_code,
exc,
)
hint = (
"Adjust the request and retry."
@@ -267,12 +318,10 @@ async def validation_error_handler(
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
app.add_exception_handler(
FolderAlreadyExistsError, handle_internal_http_error(409, False)
)
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
@@ -312,11 +361,6 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/executions",
)
app.include_router(
backend.api.features.admin.user_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/users",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],
@@ -354,6 +398,16 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.server.v2.llm.router,
tags=["v2", "llm"],
prefix="/api",
)
app.include_router(
backend.server.v2.llm.admin_router,
tags=["v2", "llm", "admin"],
prefix="/api",
)
app.mount("/external-api", external_api)
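
The lifespan change above starts a background subscription task at startup and cancels it on shutdown, awaiting the task and ignoring the resulting CancelledError. A self-contained sketch of that lifecycle is given below. The subscribe coroutine and the in-memory queue are assumptions standing in for subscribe_to_registry_refresh and Redis pub/sub, and the sketch uses try/finally where the diff cancels after the yield.

```python
import asyncio
import contextlib


async def subscribe(on_refresh, queue: asyncio.Queue) -> None:
    """Hypothetical subscriber: run on_refresh for every notification received."""
    while True:
        await queue.get()
        await on_refresh()


@contextlib.asynccontextmanager
async def lifespan(on_refresh, queue: asyncio.Queue):
    # Start the listener in the background so startup is not blocked.
    task = asyncio.create_task(subscribe(on_refresh, queue))
    try:
        yield task
    finally:
        # Cancel on shutdown and wait until the task has actually stopped.
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task


async def demo() -> tuple[int, bool]:
    refreshes = 0

    async def on_refresh() -> None:
        nonlocal refreshes
        refreshes += 1

    queue: asyncio.Queue = asyncio.Queue()
    async with lifespan(on_refresh, queue) as task:
        await queue.put("refresh")
        await asyncio.sleep(0.01)  # give the subscriber a chance to run
    return refreshes, task.cancelled()


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Awaiting the cancelled task before continuing shutdown helps ensure the listener has released its resources before later teardown steps run.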


@@ -33,6 +33,13 @@ from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.openai_responses import (
convert_tools_to_responses_format,
extract_responses_content,
extract_responses_reasoning,
extract_responses_tool_calls,
extract_responses_usage,
)
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
@@ -111,7 +118,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
@@ -277,9 +283,6 @@ MODEL_METADATA = {
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
@@ -793,6 +796,19 @@ async def llm_call(
)
prompt = result.messages
# Sanitize unpaired surrogates in message content to prevent
# UnicodeEncodeError when httpx encodes the JSON request body.
for msg in prompt:
content = msg.get("content")
if isinstance(content, str):
try:
content.encode("utf-8")
except UnicodeEncodeError:
logger.warning("Sanitized unpaired surrogates in LLM prompt content")
msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
"utf-8", errors="replace"
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
model_max_output = llm_model.max_output_tokens or int(2**15)
@@ -801,36 +817,53 @@ async def llm_call(
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
parallel_tool_calls = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
tools_param = convert_tools_to_responses_format(tools) if tools else openai.omit
text_config = openai.omit
if force_json_output:
response_format = {"type": "json_object"}
text_config = {"format": {"type": "json_object"}} # type: ignore
response = await oai_client.chat.completions.create(
response = await oai_client.responses.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls,
input=prompt, # type: ignore[arg-type]
tools=tools_param, # type: ignore[arg-type]
max_output_tokens=max_tokens,
parallel_tool_calls=get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
),
text=text_config, # type: ignore[arg-type]
store=False,
)
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
raw_tool_calls = extract_responses_tool_calls(response)
tool_calls = (
[
ToolContentBlock(
id=tc["id"],
type=tc["type"],
function=ToolCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
for tc in raw_tool_calls
]
if raw_tool_calls
else None
)
reasoning = extract_responses_reasoning(response)
content = extract_responses_content(response)
prompt_tokens, completion_tokens = extract_responses_usage(response)
return LLMResponse(
raw_response=response.choices[0].message,
raw_response=response,
prompt=prompt,
response=response.choices[0].message.content or "",
response=content,
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
reasoning=reasoning,
)
elif provider == "anthropic":
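
The surrogate sanitisation added to llm_call above can be tried in isolation. The sketch below applies the same encode/decode round trip to a single string; the function name sanitize_surrogates is hypothetical, chosen here for illustration.

```python
def sanitize_surrogates(text: str) -> str:
    """Replace unpaired surrogates so the text can be encoded as UTF-8."""
    try:
        text.encode("utf-8")
        return text  # already clean, return unchanged
    except UnicodeEncodeError:
        # surrogatepass lets the lone surrogate through as raw bytes;
        # decoding with errors="replace" then swaps those invalid bytes
        # for U+FFFD replacement characters.
        return text.encode("utf-8", errors="surrogatepass").decode(
            "utf-8", errors="replace"
        )


if __name__ == "__main__":
    cleaned = sanitize_surrogates("a\ud800b")
    print(cleaned.encode("utf-8"))
```

Text without surrogates takes the fast path and is returned as the same string, so only malformed content pays for the round trip.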


@@ -61,20 +61,27 @@ class ExecutionParams(BaseModel):
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
Supports both OpenAI and Anthropics formats.
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
"""
tool_call_ids = []
# OpenAI Responses API: function_call items have type="function_call"
if entry.get("type") == "function_call":
if call_id := entry.get("call_id"):
tool_call_ids.append(call_id)
return tool_call_ids
if entry.get("role") != "assistant":
return tool_call_ids
# OpenAI: check for tool_calls in the entry.
# OpenAI Chat Completions: check for tool_calls in the entry.
calls = entry.get("tool_calls")
if isinstance(calls, list):
for call in calls:
if tool_id := call.get("id"):
tool_call_ids.append(tool_id)
# Anthropics: check content items for tool_use type.
# Anthropic: check content items for tool_use type.
content = entry.get("content")
if isinstance(content, list):
for item in content:
@@ -89,16 +96,22 @@ def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool response.
Supports both OpenAI and Anthropics formats.
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
"""
tool_call_ids: list[str] = []
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
# OpenAI Responses API: function_call_output items
if entry.get("type") == "function_call_output":
if call_id := entry.get("call_id"):
tool_call_ids.append(str(call_id))
return tool_call_ids
# OpenAI Chat Completions: a tool response message with role "tool".
if entry.get("role") == "tool":
if tool_call_id := entry.get("tool_call_id"):
tool_call_ids.append(str(tool_call_id))
# Anthropics: check content items for tool_result type.
# Anthropic: check content items for tool_result type.
if entry.get("role") == "user":
content = entry.get("content")
if isinstance(content, list):
@@ -111,14 +124,16 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
return tool_call_ids
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
def _create_tool_response(
call_id: str, output: Any, *, responses_api: bool = False
) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropics,
based on the tool_id format.
Create a tool response message for OpenAI, Anthropic, or OpenAI Responses API,
based on the tool_id format and the responses_api flag.
"""
content = output if isinstance(output, str) else json.dumps(output)
# Anthropics format: tool IDs typically start with "toolu_"
# Anthropic format: tool IDs typically start with "toolu_"
if call_id.startswith("toolu_"):
return {
"role": "user",
@@ -128,8 +143,11 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
],
}
# OpenAI format: tool IDs typically start with "call_".
# Or default fallback (if the tool_id doesn't match any known prefix)
# OpenAI Responses API format
if responses_api:
return {"type": "function_call_output", "call_id": call_id, "output": content}
# OpenAI Chat Completions format (default fallback)
return {"role": "tool", "tool_call_id": call_id, "content": content}
@@ -177,10 +195,19 @@ def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str
return tool_outputs
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
def _convert_raw_response_to_dict(
raw_response: Any,
) -> dict[str, Any] | list[dict[str, Any]]:
"""
Safely convert raw_response to dictionary format for conversation history.
Handles different response types from different LLM providers.
For the OpenAI Responses API, the raw_response is the entire Response
object. Its ``output`` items (messages, function_calls) are extracted
individually so they can be used as valid input items on the next call.
Returns a **list** of dicts in that case.
For Chat Completions / Anthropic / Ollama, returns a single dict.
"""
if isinstance(raw_response, str):
# Ollama returns a string, convert to dict format
@@ -188,11 +215,28 @@ def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
elif isinstance(raw_response, dict):
# Already a dict (from tests or some providers)
return raw_response
elif _is_responses_api_object(raw_response):
# OpenAI Responses API: extract individual output items
items = [json.to_dict(item) for item in raw_response.output]
return items if items else [{"role": "assistant", "content": ""}]
else:
# OpenAI/Anthropic return objects, convert with json.to_dict
# Chat Completions / Anthropic return message objects
return json.to_dict(raw_response)
def _is_responses_api_object(obj: Any) -> bool:
"""Detect an OpenAI Responses API Response object.
These have ``object == "response"`` and an ``output`` list, but no
``role`` attribute (unlike ChatCompletionMessage).
"""
return (
getattr(obj, "object", None) == "response"
and hasattr(obj, "output")
and not hasattr(obj, "role")
)
def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str, int]:
"""
All the tool calls entry in the conversation history requires a response.
@@ -754,19 +798,34 @@ class SmartDecisionMakerBlock(Block):
self, prompt: list[dict], response, tool_outputs: list | None = None
):
"""Update conversation history with response and tool outputs."""
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
assistant_message = _convert_raw_response_to_dict(response.raw_response)
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
item.get("type") == "tool_use"
for item in assistant_message.get("content", [])
)
converted = _convert_raw_response_to_dict(response.raw_response)
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
if isinstance(converted, list):
# Responses API: output items are already individual dicts
has_tool_calls = any(
item.get("type") == "function_call" for item in converted
)
prompt.append(assistant_message)
if response.reasoning and not has_tool_calls:
prompt.append(
{
"role": "assistant",
"content": f"[Reasoning]: {response.reasoning}",
}
)
prompt.extend(converted)
else:
# Chat Completions / Anthropic: single assistant message dict
has_tool_calls = isinstance(converted.get("content"), list) and any(
item.get("type") == "tool_use" for item in converted.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{
"role": "assistant",
"content": f"[Reasoning]: {response.reasoning}",
}
)
prompt.append(converted)
if tool_outputs:
prompt.extend(tool_outputs)
@@ -776,6 +835,8 @@ class SmartDecisionMakerBlock(Block):
tool_info: ToolInfo,
execution_params: ExecutionParams,
execution_processor: "ExecutionProcessor",
*,
responses_api: bool = False,
) -> dict:
"""Execute a single tool using the execution manager for proper integration."""
# Lazy imports to avoid circular dependencies
@@ -868,13 +929,17 @@ class SmartDecisionMakerBlock(Block):
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(tool_call.id, tool_response_content)
return _create_tool_response(
tool_call.id, tool_response_content, responses_api=responses_api
)
except Exception as e:
logger.error(f"Tool execution with manager failed: {e}")
logger.warning(f"Tool execution with manager failed: {e}")
# Return error response
return _create_tool_response(
tool_call.id, f"Tool execution failed: {str(e)}"
tool_call.id,
f"Tool execution failed: {str(e)}",
responses_api=responses_api,
)
async def _execute_tools_agent_mode(
@@ -895,6 +960,7 @@ class SmartDecisionMakerBlock(Block):
"""Execute tools in agent mode with a loop until finished."""
max_iterations = input_data.agent_mode_max_iterations
iteration = 0
use_responses_api = input_data.model.metadata.provider == "openai"
# Execution parameters for tool execution
execution_params = ExecutionParams(
@@ -951,14 +1017,19 @@ class SmartDecisionMakerBlock(Block):
for tool_info in processed_tools:
try:
tool_response = await self._execute_single_tool_with_manager(
tool_info, execution_params, execution_processor
tool_info,
execution_params,
execution_processor,
responses_api=use_responses_api,
)
tool_outputs.append(tool_response)
except Exception as e:
logger.error(f"Tool execution failed: {e}")
# Create error response for the tool
error_response = _create_tool_response(
tool_info.tool_call.id, f"Error: {str(e)}"
tool_info.tool_call.id,
f"Error: {str(e)}",
responses_api=use_responses_api,
)
tool_outputs.append(error_response)
@@ -1020,11 +1091,17 @@ class SmartDecisionMakerBlock(Block):
if pending_tool_calls and input_data.last_tool_output is None:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
use_responses_api = input_data.model.metadata.provider == "openai"
tool_output = []
if pending_tool_calls and input_data.last_tool_output is not None:
first_call_id = next(iter(pending_tool_calls.keys()))
tool_output.append(
_create_tool_response(first_call_id, input_data.last_tool_output)
_create_tool_response(
first_call_id,
input_data.last_tool_output,
responses_api=use_responses_api,
)
)
prompt.extend(tool_output)
@@ -1056,7 +1133,9 @@ class SmartDecisionMakerBlock(Block):
)
if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
p.get("role") == "system"
and isinstance(p.get("content"), str)
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
):
prompt.append(
@@ -1067,7 +1146,9 @@ class SmartDecisionMakerBlock(Block):
)
if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
p.get("role") == "user"
and isinstance(p.get("content"), str)
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
):
prompt.append(
@@ -1175,11 +1256,26 @@ class SmartDecisionMakerBlock(Block):
)
yield emit_key, arg_value
if response.reasoning:
converted = _convert_raw_response_to_dict(response.raw_response)
# Check for tool calls to avoid inserting reasoning between tool pairs
if isinstance(converted, list):
has_tool_calls = any(
item.get("type") == "function_call" for item in converted
)
else:
has_tool_calls = isinstance(converted.get("content"), list) and any(
item.get("type") == "tool_use" for item in converted.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(_convert_raw_response_to_dict(response.raw_response))
if isinstance(converted, list):
prompt.extend(converted)
else:
prompt.append(converted)
yield "conversations", prompt
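
The two OpenAI history shapes handled above differ in where the call ID lives: Responses API items carry call_id on function_call and function_call_output entries, while Chat Completions uses tool_calls[].id on the assistant message and tool_call_id on the tool message. The sketch below pairs requests with responses for those two shapes only. The Anthropic branch is omitted, and unanswered_calls is a hypothetical helper, not the body of get_pending_tool_calls, which the diff does not show.

```python
from typing import Any


def tool_request_ids(entry: dict[str, Any]) -> list[str]:
    # Responses API: each function call is its own history item.
    if entry.get("type") == "function_call":
        call_id = entry.get("call_id")
        return [call_id] if call_id else []
    # Chat Completions: an assistant message carries a tool_calls list.
    calls = entry.get("tool_calls")
    if entry.get("role") == "assistant" and isinstance(calls, list):
        return [call["id"] for call in calls if call.get("id")]
    return []


def tool_response_ids(entry: dict[str, Any]) -> list[str]:
    # Responses API: outputs are function_call_output items.
    if entry.get("type") == "function_call_output":
        call_id = entry.get("call_id")
        return [str(call_id)] if call_id else []
    # Chat Completions: outputs are messages with role "tool".
    if entry.get("role") == "tool" and entry.get("tool_call_id"):
        return [str(entry["tool_call_id"])]
    return []


def unanswered_calls(history: list[dict[str, Any]]) -> dict[str, int]:
    """Count tool requests that have no matching response yet."""
    pending: dict[str, int] = {}
    for entry in history:
        for call_id in tool_request_ids(entry):
            pending[call_id] = pending.get(call_id, 0) + 1
        for call_id in tool_response_ids(entry):
            pending[call_id] = pending.get(call_id, 0) - 1
    return {call_id: n for call_id, n in pending.items() if n > 0}


if __name__ == "__main__":
    history = [
        {"type": "function_call", "call_id": "call_1", "name": "search", "arguments": "{}"},
        {"type": "function_call_output", "call_id": "call_1", "output": "ok"},
        {"role": "assistant", "tool_calls": [{"id": "call_2"}]},
    ]
    print(unanswered_calls(history))
```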


@@ -13,18 +13,17 @@ class TestLLMStatsTracking:
"""Test that llm_call returns proper token counts in LLMResponse."""
import backend.blocks.llm as llm
# Mock the OpenAI client
# Mock the OpenAI Responses API response
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Test response", tool_calls=None))
]
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
mock_response.output_text = "Test response"
mock_response.output = []
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
# Test with mocked OpenAI response
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
mock_client.responses.create = AsyncMock(return_value=mock_response)
response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
@@ -271,30 +270,17 @@ class TestLLMStatsTracking:
mock_response = MagicMock()
# Return different responses for chunk summary vs final summary
if call_count == 1:
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
tool_calls=None,
)
)
]
mock_response.output_text = '<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>'
else:
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
tool_calls=None,
)
)
]
mock_response.usage = MagicMock(prompt_tokens=50, completion_tokens=30)
mock_response.output_text = '<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>'
mock_response.output = []
mock_response.usage = MagicMock(input_tokens=50, output_tokens=30)
return mock_response
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_client.chat.completions.create = mock_create
mock_client.responses.create = mock_create
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(
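
The mocking pattern used in these tests can be tried outside the test suite. The sketch below assumes a caller that reads `output_text`, `usage.input_tokens`, and `usage.output_tokens` from the object returned by `client.responses.create`, as the updated mocks imply. `summarize_usage`, `build_mock_client`, and the model name are hypothetical stand-ins and do not reproduce `llm.llm_call`.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def summarize_usage(client) -> tuple:
    """Hypothetical caller: returns (text, input_tokens, output_tokens)."""
    response = await client.responses.create(model="example-model", input="hi")
    return (
        response.output_text,
        response.usage.input_tokens,
        response.usage.output_tokens,
    )


def build_mock_client(text: str, input_tokens: int, output_tokens: int) -> AsyncMock:
    """Build a client whose responses.create resolves to a canned reply."""
    mock_response = MagicMock()
    mock_response.output_text = text
    mock_response.output = []  # no tool calls in this reply
    mock_response.usage = MagicMock(
        input_tokens=input_tokens, output_tokens=output_tokens
    )
    client = AsyncMock()
    client.responses.create = AsyncMock(return_value=mock_response)
    return client


result = asyncio.run(summarize_usage(build_mock_client("Test response", 10, 20)))
```

Note the renamed usage fields: the old mocks set `prompt_tokens` and `completion_tokens`, while the new ones set `input_tokens` and `output_tokens`. A `MagicMock` silently invents any attribute that is read but never set, so a test that still asserts on the old names would compare against mock objects rather than integers.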

@@ -12,34 +12,18 @@ from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = f"""\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
### Sharing files
After `write_workspace_file`, embed the `download_url` in Markdown:
- File: `[report.csv](workspace://file_id#text/csv)`
- Image: `![chart](workspace://file_id#image/png)`
- Video: `![recording](workspace://file_id#video/mp4)`
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
- `/absolute/path` — local/sandbox files
- `[start-end]` — optional 1-indexed line range
- Multiple refs per argument supported. Only `workspace://` and absolute paths are expanded.
Examples:
```
@@ -50,21 +34,9 @@ Examples:
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Structured data**: When the entire argument is a single file reference, the platform auto-parses by extension/MIME. Supported: JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel (.xlsx only; legacy `.xls` is NOT supported). Unrecognised formats return plain string.
**Structured data**: When the **entire** argument value is a single file
reference (no surrounding text), the platform automatically parses the file
content based on its extension or MIME type. Supported formats: JSON, JSONL,
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
the rows will be parsed into `list[list[str]]` automatically. If the format is
unrecognised or parsing fails, the content is returned as a plain string.
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
**Type coercion**: The platform also coerces expanded values to match the
block's expected input types. For example, if a block expects `list[list[str]]`
and the expanded value is a JSON string, it will be parsed into the correct type.
**Type coercion**: The platform auto-coerces expanded string values to match block input types (e.g. JSON string → `list[list[str]]`).
### Media file inputs (format: "file")
Some block inputs accept media files — their schema shows `"format": "file"`.
@@ -166,17 +138,12 @@ def _build_storage_supplement(
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs {sandbox_type}.
### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations
### Shell & filesystem
- The SDK built-in Bash tool is NOT available. Use `bash_exec` for shell commands ({sandbox_type}). Working dir: `{working_dir}`
- SDK file tools (Read/Write/Edit/Glob/Grep) and `bash_exec` share one filesystem — use relative or absolute paths under this dir.
- `read_workspace_file`/`write_workspace_file` operate on **persistent cloud workspace storage** (separate from the working dir).
### Two storage systems — CRITICAL to understand
1. **{storage_system_1_name}** (`{working_dir}`):
{characteristics}
{persistence}
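
The `@@agptfile:<uri>[<start>-<end>]` syntax described in the prompt text can be illustrated with a small parser. This is only a sketch of the documented grammar, assuming a URI runs until whitespace or a bracket; `FileRef`, `find_file_refs`, and the regular expression are hypothetical and are not the platform's actual expansion code.

```python
import re
from typing import NamedTuple, Optional


class FileRef(NamedTuple):
    uri: str
    start: Optional[int]  # 1-indexed first line, or None for the whole file
    end: Optional[int]    # 1-indexed last line (inclusive), or None


# Only URIs starting with "workspace://" or "/" are recognised, as the prompt states.
_REF_RE = re.compile(
    r"@@agptfile:((?:workspace://|/)[^\s\[\]]+)(?:\[(\d+)-(\d+)\])?"
)


def find_file_refs(text: str) -> list:
    """Return every file reference found in a string argument, in order."""
    refs = []
    for match in _REF_RE.finditer(text):
        uri, start, end = match.group(1), match.group(2), match.group(3)
        refs.append(
            FileRef(uri, int(start) if start else None, int(end) if end else None)
        )
    return refs
```

For example, `find_file_refs("@@agptfile:workspace://abc123[10-20]")` yields one reference with a line range of 10 to 20, while a reference such as `@@agptfile:https://example.com/x` is ignored because it starts with neither accepted prefix. A real implementation would also need to decide how to treat trailing punctuation, which this sketch keeps as part of the URI.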

@@ -2,13 +2,11 @@
import asyncio
import base64
import functools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import uuid
@@ -77,6 +75,7 @@ from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
create_copilot_mcp_server,
get_copilot_tool_names,
@@ -458,37 +457,6 @@ def _resolve_sdk_model() -> str | None:
return model
@functools.cache
def _validate_claude_code_subscription() -> None:
"""Validate Claude CLI is installed and responds to `--version`.
Cached so the blocking subprocess check runs at most once per process
lifetime. A failure (CLI not installed) is a config error that requires
a process restart anyway.
"""
claude_path = shutil.which("claude")
if not claude_path:
raise RuntimeError(
"Claude Code CLI not found. Install it with: "
"npm install -g @anthropic-ai/claude-code"
)
result = subprocess.run(
[claude_path, "--version"],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0:
raise RuntimeError(
f"Claude CLI check failed (exit {result.returncode}): "
f"{result.stderr.strip()}"
)
logger.info(
"Claude Code subscription mode: CLI version %s",
result.stdout.strip(),
)
def _build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,

@@ -0,0 +1,144 @@
"""Claude Code subscription auth helpers.
Handles locating the SDK-bundled CLI binary, provisioning credentials from
environment variables, and validating that subscription auth is functional.
"""
import functools
import json
import logging
import os
import shutil
import subprocess
logger = logging.getLogger(__name__)
def find_bundled_cli() -> str:
"""Locate the Claude CLI binary bundled inside ``claude_agent_sdk``.
Falls back to ``shutil.which("claude")`` if the SDK bundle is absent.
"""
try:
from claude_agent_sdk._internal.transport.subprocess_cli import (
SubprocessCLITransport,
)
path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
if path:
return str(path)
except Exception:
pass
system_path = shutil.which("claude")
if system_path:
return system_path
raise RuntimeError(
"Claude CLI not found — neither the SDK-bundled binary nor a "
"system-installed `claude` could be located."
)
def provision_credentials_file() -> None:
"""Write ``~/.claude/.credentials.json`` from env when running headless.
If ``CLAUDE_CODE_OAUTH_TOKEN`` is set (an OAuth *access* token obtained
from ``claude auth status`` or extracted from the macOS keychain), this
helper writes a minimal credentials file so the bundled CLI can
authenticate without an interactive ``claude login``.
A ``CLAUDE_CODE_REFRESH_TOKEN`` env var is optional but recommended —
it lets the CLI silently refresh an expired access token.
"""
access_token = os.environ.get("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
if not access_token:
return
creds_dir = os.path.expanduser("~/.claude")
creds_path = os.path.join(creds_dir, ".credentials.json")
# Don't overwrite an existing credentials file (e.g. from a volume mount).
if os.path.exists(creds_path):
logger.debug("Credentials file already exists at %s — skipping", creds_path)
return
os.makedirs(creds_dir, exist_ok=True)
creds = {
"claudeAiOauth": {
"accessToken": access_token,
"refreshToken": os.environ.get("CLAUDE_CODE_REFRESH_TOKEN", "").strip(),
"expiresAt": 0,
"scopes": [
"user:inference",
"user:profile",
"user:sessions:claude_code",
],
}
}
with open(creds_path, "w") as f:
json.dump(creds, f)
logger.info("Provisioned Claude credentials file at %s", creds_path)
@functools.cache
def validate_subscription() -> None:
"""Validate the bundled Claude CLI is reachable and authenticated.
Cached so the blocking subprocess check runs at most once per process
lifetime. On first call, also provisions ``~/.claude/.credentials.json``
from the ``CLAUDE_CODE_OAUTH_TOKEN`` env var when available.
"""
provision_credentials_file()
cli = find_bundled_cli()
result = subprocess.run(
[cli, "--version"],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0:
raise RuntimeError(
f"Claude CLI check failed (exit {result.returncode}): "
f"{result.stderr.strip()}"
)
logger.info(
"Claude Code subscription mode: CLI version %s",
result.stdout.strip(),
)
# Verify the CLI is actually authenticated.
auth_result = subprocess.run(
[cli, "auth", "status"],
capture_output=True,
text=True,
timeout=10,
env={
**os.environ,
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
},
)
if auth_result.returncode != 0:
raise RuntimeError(
"Claude CLI is not authenticated. Either:\n"
" • Set CLAUDE_CODE_OAUTH_TOKEN env var (from `claude auth status` "
"or macOS keychain), or\n"
" • Mount ~/.claude/.credentials.json into the container, or\n"
" • Run `claude login` inside the container."
)
try:
status = json.loads(auth_result.stdout)
if not status.get("loggedIn"):
raise RuntimeError(
"Claude CLI reports loggedIn=false. Set CLAUDE_CODE_OAUTH_TOKEN "
"or run `claude login`."
)
logger.info(
"Claude subscription auth: method=%s, email=%s",
status.get("authMethod"),
status.get("email"),
)
except json.JSONDecodeError:
logger.warning("Could not parse `claude auth status` output")
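
The final parsing step of `validate_subscription` can be sketched as a pure function, which makes it testable without a subprocess. The sketch assumes only what the new file shows: that `claude auth status` may print a JSON object with a `loggedIn` field. `parse_auth_status` and `is_logged_in` are hypothetical helpers, not part of the change.

```python
import json
from typing import Optional


def parse_auth_status(stdout: str) -> Optional[dict]:
    """Return the parsed status dict, or None if stdout is not a JSON object."""
    try:
        status = json.loads(stdout)
    except json.JSONDecodeError:
        return None
    # Guard against valid JSON that is not an object (e.g. a list or a string).
    return status if isinstance(status, dict) else None


def is_logged_in(stdout: str) -> bool:
    """True only when the output parses and reports loggedIn as truthy."""
    status = parse_auth_status(stdout)
    return bool(status and status.get("loggedIn"))
```

Two differences from the code in the diff are deliberate. The diff treats unparseable output as non-fatal and only logs a warning, whereas `is_logged_in` returns `False` in that case. The diff also calls `status.get` directly, which would raise `AttributeError` if the CLI ever printed valid JSON that is not an object; the `isinstance` guard here avoids that.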

@@ -22,13 +22,12 @@ class AddUnderstandingTool(BaseTool):
@property
def description(self) -> str:
return """Capture and store information about the user's business context,
workflows, pain points, and automation goals. Call this tool whenever the user
shares information about their business. Each call incrementally adds to the
existing understanding - you don't need to provide all fields at once.
Use this to build a comprehensive profile that helps recommend better agents
and automations for the user's specific needs."""
return (
"Store user's business context, workflows, pain points, and automation goals. "
"Call whenever the user shares business info. Each call incrementally merges "
"with existing data — provide only the fields you have. "
"Builds a profile that helps recommend better agents for the user's needs."
)
@property
def parameters(self) -> dict[str, Any]:

@@ -410,18 +410,11 @@ class BrowserNavigateTool(BaseTool):
@property
def description(self) -> str:
return (
"Navigate to a URL using a real browser. Returns an accessibility "
"tree snapshot listing the page's interactive elements with @ref IDs "
"(e.g. @e3) that can be used with browser_act. "
"Session persists — cookies and login state carry over between calls. "
"Use this (with browser_act) for multi-step interaction: login flows, "
"form filling, button clicks, or anything requiring page interaction. "
"For plain static pages, prefer web_fetch — no browser overhead. "
"For authenticated pages: navigate to the login page first, use browser_act "
"to fill credentials and submit, then navigate to the target page. "
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
"state. If elements seem missing, use browser_act with action='wait' and a "
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
"Navigate to a URL in a real browser. Returns accessibility tree with @ref IDs "
"for browser_act. Session persists (cookies/auth carry over). "
"For static pages, prefer web_fetch. "
"For SPAs, elements may load late — use browser_act with wait + browser_screenshot to verify. "
"For auth: navigate to login, fill creds and submit with browser_act, then navigate to target."
)
@property
@@ -431,13 +424,13 @@ class BrowserNavigateTool(BaseTool):
"properties": {
"url": {
"type": "string",
"description": "The HTTP/HTTPS URL to navigate to.",
"description": "HTTP/HTTPS URL to navigate to.",
},
"wait_for": {
"type": "string",
"enum": ["networkidle", "load", "domcontentloaded"],
"default": "networkidle",
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
"description": "Navigation completion strategy (default: networkidle).",
},
},
"required": ["url"],
@@ -556,14 +549,12 @@ class BrowserActTool(BaseTool):
@property
def description(self) -> str:
return (
"Interact with the current browser page. Use @ref IDs from the "
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
"Interact with the current browser page using @ref IDs from the snapshot. "
"Actions: click, dblclick, fill, type, scroll, hover, press, "
"check, uncheck, select, wait, back, forward, reload. "
"fill clears the field before typing; type appends without clearing. "
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
"Example login flow: fill @e1 with email → fill @e2 with password → "
"click @e3 (submit) → browser_navigate to the target page."
"fill clears field first; type appends. "
"wait accepts CSS selector or milliseconds (e.g. '1000'). "
"Returns updated snapshot."
)
@property
@@ -589,30 +580,21 @@ class BrowserActTool(BaseTool):
"forward",
"reload",
],
"description": "The action to perform.",
"description": "Action to perform.",
},
"target": {
"type": "string",
"description": (
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
"a CSS selector, or a text description. "
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
),
"description": "@ref ID (e.g. '@e3'), CSS selector, or text. Required for: click, dblclick, fill, type, hover, check, uncheck, select. For wait: CSS selector or milliseconds string (e.g. '1000').",
},
"value": {
"type": "string",
"description": (
"For fill/type: the text to enter. "
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
"For select: the option value to select."
),
"description": "Text for fill/type, key for press (e.g. 'Enter'), option for select.",
},
"direction": {
"type": "string",
"enum": ["up", "down", "left", "right"],
"default": "down",
"description": "For scroll: direction to scroll.",
"description": "Scroll direction (default: down).",
},
},
"required": ["action"],
@@ -759,12 +741,10 @@ class BrowserScreenshotTool(BaseTool):
@property
def description(self) -> str:
return (
"Take a screenshot of the current browser page and save it to the workspace. "
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
"with the returned file_id to display the image inline to the user — "
"the screenshot is not visible until you do this. "
"With annotate=true (default), @ref labels are overlaid on interactive "
"elements, making it easy to see which @ref ID maps to which element on screen."
"Screenshot the current browser page and save to workspace. "
"annotate=true overlays @ref labels on elements. "
"IMPORTANT: After calling, you MUST immediately call read_workspace_file with the "
"returned file_id to display the image inline."
)
@property
@@ -775,12 +755,12 @@ class BrowserScreenshotTool(BaseTool):
"annotate": {
"type": "boolean",
"default": True,
"description": "Overlay @ref labels on interactive elements (default: true).",
"description": "Overlay @ref labels (default: true).",
},
"filename": {
"type": "string",
"default": "screenshot.png",
"description": "Filename to save in the workspace.",
"description": "Workspace filename (default: screenshot.png).",
},
},
}
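
The `browser_act` description says that for `wait` the target is either a CSS selector or a milliseconds string such as `'1000'`. One way a handler could tell the two apart is sketched below; `parse_wait_target` and its return tags are hypothetical, and the tool's real dispatch logic is not shown in this diff.

```python
def parse_wait_target(target: str) -> tuple:
    """Classify a wait target as a delay in milliseconds or a CSS selector."""
    stripped = target.strip()
    # An all-ASCII-digit string is read as a delay; anything else is a selector.
    if stripped.isascii() and stripped.isdigit():
        return ("delay_ms", int(stripped))
    return ("selector", stripped)
```

Under this rule `"1000"` is a one-second delay and `"#login-form"` is a selector. Since a CSS selector cannot begin with a bare digit, the two forms should not collide in practice.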

@@ -108,22 +108,12 @@ class AgentOutputTool(BaseTool):
@property
def description(self) -> str:
return """Retrieve execution outputs from agents in the user's library.
Identify the agent using one of:
- agent_name: Fuzzy search in user's library
- library_agent_id: Exact library agent ID
- store_slug: Marketplace format 'username/agent-name'
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
return (
"Retrieve execution outputs from a library agent. "
"Identify by agent_name, library_agent_id, or store_slug. "
"Filter by execution_id or run_time. "
"Optionally wait for running executions."
)
@property
def parameters(self) -> dict[str, Any]:
@@ -132,32 +122,29 @@ class AgentOutputTool(BaseTool):
"properties": {
"agent_name": {
"type": "string",
"description": "Agent name to search for in user's library (fuzzy match)",
"description": "Agent name (fuzzy match).",
},
"library_agent_id": {
"type": "string",
"description": "Exact library agent ID",
"description": "Library agent ID.",
},
"store_slug": {
"type": "string",
"description": "Marketplace identifier: 'username/agent-slug'",
"description": "Marketplace 'username/agent-name'.",
},
"execution_id": {
"type": "string",
"description": "Specific execution ID to retrieve",
"description": "Specific execution ID.",
},
"run_time": {
"type": "string",
"description": (
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
"description": "Time filter: 'latest', 'today', 'yesterday', 'last week', 'last 7 days', 'last month', 'last 30 days', 'YYYY-MM-DD', or ISO datetime.",
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
"description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
"minimum": 0,
"maximum": 300,
},
},
"required": [],

@@ -42,15 +42,9 @@ class BashExecTool(BaseTool):
@property
def description(self) -> str:
return (
"Execute a Bash command or script. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
"tools — files created by either are immediately visible to both. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing, running scripts, "
"and installing packages."
"Execute a Bash command or script. Shares filesystem with SDK file tools. "
"Useful for scripts, data processing, and package installation. "
"Killed after timeout (default 30s, max 120s)."
)
@property
@@ -60,13 +54,11 @@ class BashExecTool(BaseTool):
"properties": {
"command": {
"type": "string",
"description": "Bash command or script to execute.",
"description": "Bash command or script.",
},
"timeout": {
"type": "integer",
"description": (
"Max execution time in seconds (default 30, max 120)."
),
"description": "Max seconds (default 30, max 120).",
"default": 30,
},
},

@@ -30,12 +30,7 @@ class ContinueRunBlockTool(BaseTool):
@property
def description(self) -> str:
return (
"Continue executing a block after human review approval. "
"Use this after a run_block call returned review_required. "
"Pass the review_id from the review_required response. "
"The block will execute with the original pre-approved input data."
)
return "Resume block execution after a run_block call returned review_required. Pass the review_id."
@property
def parameters(self) -> dict[str, Any]:
@@ -44,10 +39,7 @@ class ContinueRunBlockTool(BaseTool):
"properties": {
"review_id": {
"type": "string",
"description": (
"The review_id from a previous review_required response. "
"This resumes execution with the pre-approved input data."
),
"description": "review_id from the review_required response.",
},
},
"required": ["review_id"],

@@ -23,12 +23,8 @@ class CreateAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new agent workflow. Pass `agent_json` with the complete "
"agent graph JSON you generated using block schemas from find_block. "
"The tool validates, auto-fixes, and saves.\n\n"
"IMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter."
"Create a new agent from JSON (nodes + links). Validates, auto-fixes, and saves. "
"Before calling, search for existing agents with find_library_agent."
)
@property
@@ -42,34 +38,21 @@ class CreateAgentTool(BaseTool):
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to validate and save. "
"Must contain 'nodes' and 'links' arrays, and optionally "
"'name' and 'description'."
),
"description": "Agent graph with 'nodes' and 'links' arrays.",
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks."
),
"description": "Library agent IDs as building blocks.",
},
"save": {
"type": "boolean",
"description": (
"Whether to save the agent. Default is true. "
"Set to false for preview only."
),
"description": "Save the agent (default: true). False for preview.",
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
"description": "Folder ID to save into (default: root).",
},
},
"required": ["agent_json"],

@@ -23,9 +23,7 @@ class CustomizeAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent. Pass `agent_json` "
"with the complete customized agent JSON. The tool validates, "
"auto-fixes, and saves."
"Customize a marketplace/template agent. Validates, auto-fixes, and saves."
)
@property
@@ -39,32 +37,21 @@ class CustomizeAgentTool(BaseTool):
"properties": {
"agent_json": {
"type": "object",
"description": (
"Complete customized agent JSON to validate and save. "
"Optionally include 'name' and 'description'."
),
"description": "Customized agent JSON with nodes and links.",
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks."
),
"description": "Library agent IDs as building blocks.",
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent. Default is true."
),
"description": "Save the agent (default: true). False for preview.",
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
"description": "Folder ID to save into (default: root).",
},
},
"required": ["agent_json"],

@@ -23,12 +23,8 @@ class EditAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Edit an existing agent. Pass `agent_json` with the complete "
"updated agent JSON you generated. The tool validates, auto-fixes, "
"and saves.\n\n"
"IMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks."
"Edit an existing agent. Validates, auto-fixes, and saves. "
"Before calling, search for existing agents with find_library_agent."
)
@property
@@ -42,33 +38,20 @@ class EditAgentTool(BaseTool):
"properties": {
"agent_id": {
"type": "string",
"description": (
"The ID of the agent to edit. "
"Can be a graph ID or library agent ID."
),
"description": "Graph ID or library agent ID to edit.",
},
"agent_json": {
"type": "object",
"description": (
"Complete updated agent JSON to validate and save. "
"Must contain 'nodes' and 'links'. "
"Include 'name' and/or 'description' if they need "
"to be updated."
),
"description": "Updated agent JSON with nodes and links.",
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes."
),
"description": "Library agent IDs as building blocks.",
},
"save": {
"type": "boolean",
"description": (
"Whether to save the changes. "
"Default is true. Set to false for preview only."
),
"description": "Save changes (default: true). False for preview.",
"default": True,
},
},

@@ -134,11 +134,7 @@ class SearchFeatureRequestsTool(BaseTool):
@property
def description(self) -> str:
return (
"Search existing feature requests to check if a similar request "
"already exists before creating a new one. Returns matching feature "
"requests with their ID, title, and description."
)
return "Search existing feature requests. Check before creating a new one."
@property
def parameters(self) -> dict[str, Any]:
@@ -234,14 +230,9 @@ class CreateFeatureRequestTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new feature request or add a customer need to an existing one. "
"Always search first with search_feature_requests to avoid duplicates. "
"If a matching request exists, pass its ID as existing_issue_id to add "
"the user's need to it instead of creating a duplicate. "
"IMPORTANT: Never include personally identifiable information (PII) in "
"the title or description — no names, emails, phone numbers, company "
"names, or other identifying details. Write titles and descriptions in "
"generic, feature-focused language."
"Create a feature request or add need to existing one. "
"Search first to avoid duplicates. Pass existing_issue_id to add to existing. "
"Never include PII (names, emails, phone numbers, company names) in title/description."
)
@property
@@ -251,28 +242,15 @@ class CreateFeatureRequestTool(BaseTool):
"properties": {
"title": {
"type": "string",
"description": (
"Title for the feature request. Must be generic and "
"feature-focused — do not include any user names, emails, "
"company names, or other PII."
),
"description": "Feature request title. No names, emails, or company info.",
},
"description": {
"type": "string",
"description": (
"Detailed description of what the user wants and why. "
"Must not contain any personally identifiable information "
"(PII) — describe the feature need generically without "
"referencing specific users, companies, or contact details."
),
"description": "What the user wants and why. No names, emails, or company info.",
},
"existing_issue_id": {
"type": "string",
"description": (
"If adding a need to an existing feature request, "
"provide its Linear issue ID (from search results). "
"Omit to create a new feature request."
),
"description": "Linear issue ID to add need to (from search results).",
},
},
"required": ["title", "description"],

@@ -18,10 +18,7 @@ class FindAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Discover agents from the marketplace based on capabilities and "
"user needs, or look up a specific agent by its creator/slug ID."
)
return "Search marketplace agents by capability, or look up by slug ('username/agent-name')."
@property
def parameters(self) -> dict[str, Any]:
@@ -30,7 +27,7 @@ class FindAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": "Search query describing what the user wants to accomplish, or a creator/slug ID (e.g. 'username/agent-name') for direct lookup. Use single keywords for best results.",
"description": "Search keywords, or 'username/agent-name' for direct slug lookup.",
},
},
"required": ["query"],

@@ -54,13 +54,9 @@ class FindBlockTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for available blocks by name or description, or look up a "
"specific block by its ID. "
"Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. "
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
"The response includes each block's id, name, and description. "
"Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
"Search blocks by name or description. Returns block IDs for run_block. "
"Always call this FIRST to get block IDs before using run_block. "
"Then call run_block with the block's id and empty input_data to see its detailed schema."
)
@property
@@ -70,19 +66,11 @@ class FindBlockTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find blocks by name or description, "
"or a block ID (UUID) for direct lookup. "
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
"description": "Search keywords (e.g. 'email', 'http', 'ai').",
},
"include_schemas": {
"type": "boolean",
"description": (
"If true, include full input_schema and output_schema "
"for each block. Use when generating agent JSON that "
"needs block schemas. Default is false."
),
"description": "Include full input/output schemas (for agent JSON generation).",
"default": False,
},
},

@@ -19,13 +19,8 @@ class FindLibraryAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"When creating agents with sub-agent composition, use this to get "
"the agent's graph_id, graph_version, input_schema, and output_schema "
"needed for AgentExecutorBlock nodes. "
"Omit the query to list all agents."
"Search user's library agents. Returns graph_id, schemas for sub-agent composition. "
"Omit query to list all."
)
@property
@@ -35,10 +30,7 @@ class FindLibraryAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find agents by name or description. "
"Omit to list all agents in the library."
),
"description": "Search by name/description. Omit to list all.",
},
},
"required": [],

View File

@@ -22,20 +22,10 @@ class FixAgentGraphTool(BaseTool):
@property
def description(self) -> str:
return (
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
"- Missing or invalid UUIDs on nodes and links\n"
"- StoreValueBlock prerequisites for ConditionBlock\n"
"- Double curly brace escaping in prompt templates\n"
"- AddToList/AddToDictionary prerequisite blocks\n"
"- CodeExecutionBlock output field naming\n"
"- Missing credentials configuration\n"
"- Node X coordinate spacing (800+ units apart)\n"
"- AI model default parameters\n"
"- Link static properties based on input schema\n"
"- Type mismatches (inserts conversion blocks)\n\n"
"Returns the fixed agent JSON plus a list of fixes applied. "
"After fixing, the agent is re-validated. If still invalid, "
"the remaining errors are included in the response."
"Auto-fix common agent JSON issues: missing/invalid UUIDs, StoreValueBlock prerequisites, "
"double curly brace escaping, AddToList/AddToDictionary prerequisites, credentials, "
"node spacing, AI model defaults, link static properties, and type mismatches. "
"Returns fixed JSON and list of fixes applied."
)
@property

View File

@@ -42,12 +42,7 @@ class GetAgentBuildingGuideTool(BaseTool):
@property
def description(self) -> str:
return (
"Returns the complete guide for building agent JSON graphs, including "
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
"Call this before generating agent JSON to ensure correct structure."
)
return "Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage). Call before generating agent JSON."
@property
def parameters(self) -> dict[str, Any]:

View File

@@ -25,8 +25,7 @@ class GetDocPageTool(BaseTool):
@property
def description(self) -> str:
return (
"Get the full content of a documentation page by its path. "
"Use this after search_docs to read the complete content of a relevant page."
"Read full documentation page content by path (from search_docs results)."
)
@property
@@ -36,10 +35,7 @@ class GetDocPageTool(BaseTool):
"properties": {
"path": {
"type": "string",
"description": (
"The path to the documentation file, as returned by search_docs. "
"Example: 'platform/block-sdk-guide.md'"
),
"description": "Doc file path (e.g. 'platform/block-sdk-guide.md').",
},
},
"required": ["path"],

View File

@@ -38,11 +38,7 @@ class GetMCPGuideTool(BaseTool):
@property
def description(self) -> str:
return (
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
"Call before using run_mcp_tool if you need a server URL or auth info."
)
return "Get MCP server URLs and auth guide. Call before run_mcp_tool if you need a server URL or auth info."
@property
def parameters(self) -> dict[str, Any]:

View File

@@ -88,10 +88,7 @@ class CreateFolderTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new folder in the user's library to organize agents. "
"Optionally nest it inside an existing folder using parent_id."
)
return "Create a library folder. Use parent_id to nest inside another folder."
@property
def requires_auth(self) -> bool:
@@ -104,22 +101,19 @@ class CreateFolderTool(BaseTool):
"properties": {
"name": {
"type": "string",
"description": "Name for the new folder (max 100 chars).",
"description": "Folder name (max 100 chars).",
},
"parent_id": {
"type": "string",
"description": (
"ID of the parent folder to nest inside. "
"Omit to create at root level."
),
"description": "Parent folder ID (omit for root).",
},
"icon": {
"type": "string",
"description": "Optional icon identifier for the folder.",
"description": "Icon identifier.",
},
"color": {
"type": "string",
"description": "Optional hex color code (#RRGGBB).",
"description": "Hex color (#RRGGBB).",
},
},
"required": ["name"],
@@ -175,13 +169,9 @@ class ListFoldersTool(BaseTool):
@property
def description(self) -> str:
return (
"List the user's library folders. "
"Omit parent_id to get the full folder tree. "
"Provide parent_id to list only direct children of that folder. "
"Set include_agents=true to also return the agents inside each folder "
"and root-level agents not in any folder. Always set include_agents=true "
"when the user asks about agents, wants to see what's in their folders, "
"or mentions agents alongside folders."
"List library folders. Omit parent_id for full tree. "
"Set include_agents=true when user asks about agents, wants to see "
"what's in their folders, or mentions agents alongside folders."
)
@property
@@ -195,17 +185,11 @@ class ListFoldersTool(BaseTool):
"properties": {
"parent_id": {
"type": "string",
"description": (
"List children of this folder. "
"Omit to get the full folder tree."
),
"description": "List children of this folder (omit for full tree).",
},
"include_agents": {
"type": "boolean",
"description": (
"Whether to include the list of agents inside each folder. "
"Defaults to false."
),
"description": "Include agents in each folder (default: false).",
},
},
"required": [],
@@ -357,10 +341,7 @@ class MoveFolderTool(BaseTool):
@property
def description(self) -> str:
return (
"Move a folder to a different parent folder. "
"Set target_parent_id to null to move to root level."
)
return "Move a folder. Set target_parent_id to null for root."
@property
def requires_auth(self) -> bool:
@@ -373,14 +354,11 @@ class MoveFolderTool(BaseTool):
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to move.",
"description": "Folder ID.",
},
"target_parent_id": {
"type": ["string", "null"],
"description": (
"ID of the new parent folder. "
"Use null to move to root level."
),
"description": "New parent folder ID (null for root).",
},
},
"required": ["folder_id"],
@@ -433,10 +411,7 @@ class DeleteFolderTool(BaseTool):
@property
def description(self) -> str:
return (
"Delete a folder from the user's library. "
"Agents inside the folder are moved to root level (not deleted)."
)
return "Delete a folder. Agents inside move to root (not deleted)."
@property
def requires_auth(self) -> bool:
@@ -499,10 +474,7 @@ class MoveAgentsToFolderTool(BaseTool):
@property
def description(self) -> str:
return (
"Move one or more agents to a folder. "
"Set folder_id to null to move agents to root level."
)
return "Move agents to a folder. Set folder_id to null for root."
@property
def requires_auth(self) -> bool:
@@ -516,13 +488,11 @@ class MoveAgentsToFolderTool(BaseTool):
"agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": "List of library agent IDs to move.",
"description": "Library agent IDs to move.",
},
"folder_id": {
"type": ["string", "null"],
"description": (
"Target folder ID. Use null to move to root level."
),
"description": "Target folder ID (null for root).",
},
},
"required": ["agent_ids"],

View File

@@ -104,19 +104,11 @@ class RunAgentTool(BaseTool):
@property
def description(self) -> str:
return """Run or schedule an agent from the marketplace or user's library.
The tool automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
- Returns missing credentials if user needs to configure them
- Executes immediately if all requirements are met
- Schedules execution if cron expression is provided
Identify the agent using either:
- username_agent_slug: Marketplace format 'username/agent-name'
- library_agent_id: ID of an agent in the user's library
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
return (
"Run or schedule an agent. Automatically checks inputs and credentials. "
"Identify by username_agent_slug ('user/agent') or library_agent_id. "
"For scheduling, provide schedule_name + cron."
)
@property
def parameters(self) -> dict[str, Any]:
@@ -125,40 +117,38 @@ class RunAgentTool(BaseTool):
"properties": {
"username_agent_slug": {
"type": "string",
"description": "Agent identifier in format 'username/agent-name'",
"description": "Marketplace format 'username/agent-name'.",
},
"library_agent_id": {
"type": "string",
"description": "Library agent ID from user's library",
"description": "Library agent ID.",
},
"inputs": {
"type": "object",
"description": "Input values for the agent",
"description": "Input values for the agent.",
"additionalProperties": True,
},
"use_defaults": {
"type": "boolean",
"description": "Set to true to run with default values (user must confirm)",
"description": "Run with default values (confirm with user first).",
},
"schedule_name": {
"type": "string",
"description": "Name for scheduled execution (triggers scheduling mode)",
"description": "Name for scheduled execution. Providing this triggers scheduling mode (also requires cron).",
},
"cron": {
"type": "string",
"description": "Cron expression (5 fields: min hour day month weekday)",
"description": "Cron expression (min hour day month weekday).",
},
"timezone": {
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
"description": "IANA timezone (default: UTC).",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
"description": "Max seconds to wait for completion (0-300).",
"minimum": 0,
"maximum": 300,
},
},
"required": [],
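The trimmed RunAgentTool schema leaves `required` empty, so the constraints in the descriptions (one of two identifiers, `schedule_name` implying `cron`, the 0-300 wait range) have to be enforced at call time. A minimal sketch of such a check follows. `check_run_agent_args` is a hypothetical helper written for illustration and is not part of this diff; the real tool's validation may differ.

```python
from typing import Any


def check_run_agent_args(args: dict[str, Any]) -> list[str]:
    """Return a list of problems; an empty list means the call looks well-formed.

    Hypothetical helper: it mirrors the constraints stated in the parameter
    descriptions, not the tool's actual implementation.
    """
    problems: list[str] = []

    # The agent must be identified by exactly the fields the schema names.
    if not (args.get("username_agent_slug") or args.get("library_agent_id")):
        problems.append("provide username_agent_slug or library_agent_id")

    # Providing schedule_name switches to scheduling mode, which needs a cron.
    if args.get("schedule_name"):
        cron = args.get("cron") or ""
        if len(cron.split()) != 5:
            problems.append("cron needs 5 fields: min hour day month weekday")

    # wait_for_result is bounded by the schema's minimum/maximum.
    wait = args.get("wait_for_result", 0)
    if not 0 <= wait <= 300:
        problems.append("wait_for_result must be between 0 and 300")

    return problems
```

For example, `check_run_agent_args({"library_agent_id": "abc", "schedule_name": "daily", "cron": "0 9 * * *"})` returns an empty list, while dropping one cron field yields a single problem.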

View File

@@ -45,13 +45,10 @@ class RunBlockTool(BaseTool):
@property
def description(self) -> str:
return (
"Execute a specific block with the provided input data. "
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing "
"required inputs and outputs. Then call again with proper input_data to execute. "
"If a block requires human review, use continue_run_block with the "
"review_id after the user approves."
"Execute a block. IMPORTANT: Always get block_id from find_block first "
"— do NOT guess or fabricate IDs. "
"Call with empty input_data to see schema, then with data to execute. "
"If review_required, use continue_run_block."
)
@property
@@ -61,28 +58,14 @@ class RunBlockTool(BaseTool):
"properties": {
"block_id": {
"type": "string",
"description": (
"The block's 'id' field from find_block results. "
"NEVER guess this - always get it from find_block first."
),
},
"block_name": {
"type": "string",
"description": (
"The block's human-readable name from find_block results. "
"Used for display purposes in the UI."
),
"description": "Block ID from find_block results.",
},
"input_data": {
"type": "object",
"description": (
"Input values for the block. "
"First call with empty {} to see the block's schema, "
"then call again with proper values to execute."
),
"description": "Input values. Use {} first to see schema.",
},
},
"required": ["block_id", "block_name", "input_data"],
"required": ["block_id", "input_data"],
}
@property
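The two-step convention in the RunBlockTool description (call with empty `input_data` to see the schema, then call again with data to execute) can be sketched as below. `run_block_sketch` and `_DEMO_SCHEMA` are hypothetical names used only for illustration; the real block lookup and execution are not shown in this diff.

```python
from typing import Any

# Hypothetical stand-in for a block's input schema. In the real tool the
# schema would come from the block identified by block_id.
_DEMO_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {"to": {"type": "string"}, "body": {"type": "string"}},
    "required": ["to", "body"],
}


def run_block_sketch(block_id: str, input_data: dict[str, Any]) -> dict[str, Any]:
    """Illustrative two-step flow: schema on empty input, execution otherwise."""
    if not input_data:
        # First call: nothing to execute yet, so describe what is needed.
        return {"type": "schema", "block_id": block_id, "input_schema": _DEMO_SCHEMA}

    # Second call: refuse to execute until every required input is present.
    missing = [name for name in _DEMO_SCHEMA["required"] if name not in input_data]
    if missing:
        return {"type": "error", "missing": missing}

    return {"type": "executed", "block_id": block_id, "inputs": input_data}
```

Keeping `input_data` required while allowing `{}` gives the caller one entry point for both discovery and execution, which is presumably why `block_name` could be dropped from `required` without changing the flow.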

View File

@@ -57,10 +57,9 @@ class RunMCPToolTool(BaseTool):
@property
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Two-step: (1) call with server_url to list available tools, "
"(2) call again with server_url + tool_name + tool_arguments to execute. "
"Call get_mcp_guide for known server URLs and auth details."
"Discover and execute MCP server tools. "
"Call with server_url only to list tools, then with tool_name + tool_arguments to execute. "
"Call get_mcp_guide first for server URLs and auth."
)
@property
@@ -70,24 +69,15 @@ class RunMCPToolTool(BaseTool):
"properties": {
"server_url": {
"type": "string",
"description": (
"URL of the MCP server (Streamable HTTP endpoint), "
"e.g. https://mcp.example.com/mcp"
),
"description": "MCP server URL (Streamable HTTP endpoint).",
},
"tool_name": {
"type": "string",
"description": (
"Name of the MCP tool to execute. "
"Omit on first call to discover available tools."
),
"description": "Tool to execute. Omit to discover available tools.",
},
"tool_arguments": {
"type": "object",
"description": (
"Arguments to pass to the selected tool. "
"Must match the tool's input schema returned during discovery."
),
"description": "Arguments matching the tool's input schema.",
},
},
"required": ["server_url"],

View File

@@ -38,11 +38,7 @@ class SearchDocsTool(BaseTool):
@property
def description(self) -> str:
return (
"Search the AutoGPT platform documentation for information about "
"how to use the platform, build agents, configure blocks, and more. "
"Returns relevant documentation sections. Use get_doc_page to read full content."
)
return "Search platform documentation by keyword. Use get_doc_page to read full results."
@property
def parameters(self) -> dict[str, Any]:
@@ -51,10 +47,7 @@ class SearchDocsTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find relevant documentation. "
"Use natural language to describe what you're looking for."
),
"description": "Documentation search query.",
},
},
"required": ["query"],

View File

@@ -0,0 +1,119 @@
"""Schema regression tests for all registered CoPilot tools.
Validates that every tool in TOOL_REGISTRY produces a well-formed schema:
- description is non-empty
- all `required` fields exist in `properties`
- every property has a `type` and `description`
- total schema character budget does not regress past threshold
"""
import json
from typing import Any, cast
import pytest
from backend.copilot.tools import TOOL_REGISTRY
# Character budget (~4 chars/token heuristic, targeting ~8000 tokens)
_CHAR_BUDGET = 32_000
@pytest.fixture(scope="module")
def all_tool_schemas() -> list[tuple[str, Any]]:
"""Return (tool_name, openai_schema) pairs for every registered tool."""
return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]
def _get_parametrize_data() -> list[tuple[str, object]]:
"""Build parametrize data at collection time."""
return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]
@pytest.mark.parametrize(
"tool_name,schema",
_get_parametrize_data(),
ids=[name for name, _ in _get_parametrize_data()],
)
class TestToolSchema:
"""Validate schema invariants for every registered tool."""
def test_description_non_empty(self, tool_name: str, schema: dict) -> None:
desc = schema["function"].get("description", "")
assert desc, f"Tool '{tool_name}' has an empty description"
def test_required_fields_exist_in_properties(
self, tool_name: str, schema: dict
) -> None:
params = schema["function"].get("parameters", {})
properties = params.get("properties", {})
required = params.get("required", [])
for field in required:
assert field in properties, (
f"Tool '{tool_name}': required field '{field}' "
f"not found in properties {list(properties.keys())}"
)
def test_every_property_has_type_and_description(
self, tool_name: str, schema: dict
) -> None:
params = schema["function"].get("parameters", {})
properties = params.get("properties", {})
for prop_name, prop_def in properties.items():
assert (
"type" in prop_def
), f"Tool '{tool_name}', property '{prop_name}' is missing 'type'"
assert (
"description" in prop_def
), f"Tool '{tool_name}', property '{prop_name}' is missing 'description'"
def test_browser_act_action_enum_complete() -> None:
"""Assert browser_act action enum still contains all 14 supported actions.
This prevents future PRs from accidentally dropping actions during description
trimming. The enum is the authoritative list — this locks it at 14 values.
"""
tool = TOOL_REGISTRY["browser_act"]
schema = tool.as_openai_tool()
fn_def = schema["function"]
params = cast(dict[str, Any], fn_def.get("parameters", {}))
actions = params["properties"]["action"]["enum"]
expected = {
"click",
"dblclick",
"fill",
"type",
"scroll",
"hover",
"press",
"check",
"uncheck",
"select",
"wait",
"back",
"forward",
"reload",
}
assert set(actions) == expected, (
f"browser_act action enum changed. Got {set(actions)}, expected {expected}. "
"If you added/removed an action, update this test intentionally."
)
def test_total_schema_char_budget() -> None:
"""Assert total tool schema size stays under the character budget.
This locks in the 34% token reduction from #12398 and prevents future
description bloat from eroding the gains. Uses character count with a
~4 chars/token heuristic (budget of 32000 chars ≈ 8000 tokens).
Character count is tokenizer-agnostic — no dependency on GPT or Claude
tokenizers — while still providing a stable regression gate.
"""
schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
serialized = json.dumps(schemas)
total_chars = len(serialized)
assert total_chars < _CHAR_BUDGET, (
f"Tool schemas use {total_chars} chars (~{total_chars // 4} tokens), "
f"exceeding budget of {_CHAR_BUDGET} chars (~{_CHAR_BUDGET // 4} tokens). "
f"Description bloat detected — trim descriptions or raise the budget intentionally."
)
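The budget test above measures `len(json.dumps(schemas))` and divides by four to estimate tokens. A small sketch of that measurement on its own, using a made-up schema rather than the real TOOL_REGISTRY, is shown here. `schema_char_cost` is a hypothetical helper, and the 4 chars/token figure is the same rough heuristic the test uses, not a tokenizer result.

```python
import json
from typing import Any


def schema_char_cost(schemas: list[dict[str, Any]]) -> tuple[int, int]:
    """Return (character count, rough token estimate) for serialized schemas."""
    serialized = json.dumps(schemas)
    # ~4 characters per token is a heuristic, not a tokenizer measurement.
    return len(serialized), len(serialized) // 4


# Hypothetical schema used only to exercise the helper.
demo_schemas = [
    {
        "type": "function",
        "function": {
            "name": "demo_tool",
            "description": "Search keywords.",
            "parameters": {"type": "object", "properties": {}, "required": []},
        },
    }
]

chars, approx_tokens = schema_char_cost(demo_schemas)
# Compact separators drop the spaces json.dumps inserts by default.
compact_chars = len(json.dumps(demo_schemas, separators=(",", ":")))
```

Note that `json.dumps` with default separators adds a space after every comma and colon, so the measured count is slightly higher than a compact serialization of the same schemas. That is harmless for a regression gate as long as the same serialization is used each time.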

View File

@@ -22,17 +22,9 @@ class ValidateAgentGraphTool(BaseTool):
@property
def description(self) -> str:
return (
"Validate an agent JSON graph for correctness. Checks:\n"
"- All block_ids reference real blocks\n"
"- All links reference valid source/sink nodes and fields\n"
"- Required input fields are wired or have defaults\n"
"- Data types are compatible across links\n"
"- Nested sink links use correct notation\n"
"- Prompt templates use proper curly brace escaping\n"
"- AgentExecutorBlock configurations are valid\n\n"
"Call this after generating agent JSON to verify correctness. "
"If validation fails, either fix issues manually based on the error "
"descriptions, or call fix_agent_graph to auto-fix common problems."
"Validate agent JSON for correctness: block_ids, links, required fields, "
"type compatibility, nested sink notation, prompt brace escaping, "
"and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix."
)
@property
@@ -46,11 +38,7 @@ class ValidateAgentGraphTool(BaseTool):
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
"Each node needs: id (UUID), block_id, input_default, metadata. "
"Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
),
"description": "Agent JSON with 'nodes' and 'links' arrays.",
},
},
"required": ["agent_json"],

View File

@@ -59,13 +59,7 @@ class WebFetchTool(BaseTool):
@property
def description(self) -> str:
return (
"Fetch the content of a public web page by URL. "
"Returns readable text extracted from HTML by default. "
"Useful for reading documentation, articles, and API responses. "
"Only supports HTTP/HTTPS GET requests to public URLs "
"(private/internal network addresses are blocked)."
)
return "Fetch a public web page. Public URLs only — internal addresses blocked. Returns readable text from HTML by default."
@property
def parameters(self) -> dict[str, Any]:
@@ -74,14 +68,11 @@ class WebFetchTool(BaseTool):
"properties": {
"url": {
"type": "string",
"description": "The public HTTP/HTTPS URL to fetch.",
"description": "Public HTTP/HTTPS URL.",
},
"extract_text": {
"type": "boolean",
"description": (
"If true (default), extract readable text from HTML. "
"If false, return raw content."
),
"description": "Extract text from HTML (default: true).",
"default": True,
},
},

View File

@@ -27,6 +27,8 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
_MAX_FILE_SIZE_MB = Config().max_file_size_mb
# Sentinel file_id used when a tool-result file is read directly from the local
# host filesystem (rather than from workspace storage).
_LOCAL_TOOL_RESULT_FILE_ID = "local"
@@ -415,13 +417,7 @@ class ListWorkspaceFilesTool(BaseTool):
@property
def description(self) -> str:
return (
"List files in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
return "List persistent workspace files. For ephemeral session files, use SDK Glob/Read instead. Optionally filter by path prefix."
@property
def parameters(self) -> dict[str, Any]:
@@ -430,24 +426,17 @@ class ListWorkspaceFilesTool(BaseTool):
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
"description": "Filter by path prefix (e.g. '/documents/').",
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"description": "Max files to return (default 50, max 100).",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
"description": "Include files from all sessions (default: false).",
},
},
"required": [],
@@ -530,18 +519,11 @@ class ReadWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Read a file from the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Use 'save_to_path' to copy the file to the working directory "
"(sandbox or ephemeral) for processing with bash_exec or file tools. "
"Use 'offset' and 'length' for paginated reads of large files "
"(e.g., persisted tool outputs). "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
"Read a file from persistent workspace. Specify file_id or path. "
"Small text/image files return inline; large/binary return metadata+URL. "
"Use save_to_path to copy to working dir for processing. "
"Use offset/length for paginated reads. "
"Paths scoped to current session; use /sessions/<id>/... for cross-session access."
)
@property
@@ -551,48 +533,30 @@ class ReadWorkspaceFileTool(BaseTool):
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
"description": "File ID from list_workspace_files.",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
"description": "Virtual file path (e.g. '/documents/report.pdf').",
},
"save_to_path": {
"type": "string",
"description": (
"If provided, save the file to this path in the working "
"directory (cloud sandbox when E2B is active, or "
"ephemeral dir otherwise) so it can be processed with "
"bash_exec or file tools. "
"The file content is still returned in the response."
),
"description": "Copy file to this working directory path for processing.",
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
"description": "Always return metadata+URL instead of inline content.",
},
"offset": {
"type": "integer",
"description": (
"Character offset to start reading from (0-based). "
"Use with 'length' for paginated reads of large files."
),
"description": "Character offset for paginated reads (0-based).",
},
"length": {
"type": "integer",
"description": (
"Maximum number of characters to return. "
"Defaults to full file. Use with 'offset' for paginated reads."
),
"description": "Max characters to return for paginated reads.",
},
},
"required": [], # At least one must be provided
"required": [], # At least one of file_id or path must be provided
}
@property
@@ -755,15 +719,10 @@ class WriteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Write or create a file in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Provide content as plain text via 'content', OR base64-encoded via "
"'content_base64', OR copy a file from the ephemeral working directory "
"via 'source_path'. Exactly one of these three is required. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
"Write a file to persistent workspace (survives across sessions). "
"Provide exactly one of: content (text), content_base64 (binary), "
f"or source_path (copy from working dir). Max {_MAX_FILE_SIZE_MB}MB. "
"Paths scoped to current session; use /sessions/<id>/... for cross-session access."
)
@property
@@ -773,51 +732,31 @@ class WriteWorkspaceFileTool(BaseTool):
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
"description": "Filename (e.g. 'report.pdf').",
},
"content": {
"type": "string",
"description": (
"Plain text content to write. Use this for text files "
"(code, configs, documents, etc.). "
"Mutually exclusive with content_base64 and source_path."
),
"description": "Plain text content. Mutually exclusive with content_base64/source_path.",
},
"content_base64": {
"type": "string",
"description": (
"Base64-encoded file content. Use this for binary files "
"(images, PDFs, etc.). "
"Mutually exclusive with content and source_path."
),
"description": "Base64-encoded binary content. Mutually exclusive with content/source_path.",
},
"source_path": {
"type": "string",
"description": (
"Path to a file in the ephemeral working directory to "
"copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
"Use this to persist files created by bash_exec or SDK Write. "
"Mutually exclusive with content and content_base64."
),
"description": "Working directory path to copy to workspace. Mutually exclusive with content/content_base64.",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
"description": "Virtual path (e.g. '/documents/report.pdf'). Defaults to '/{filename}'.",
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
"description": "MIME type. Auto-detected from filename if omitted.",
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
"description": "Overwrite if file exists (default: false).",
},
},
"required": ["filename"],
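Only `filename` is listed as required, so the "exactly one of content, content_base64, or source_path" rule from the description has to be checked when the tool runs. One way to express that check is sketched here. `pick_content_source` is a hypothetical helper for illustration and is not taken from this diff.

```python
from typing import Any

_CONTENT_SOURCES = ("content", "content_base64", "source_path")


def pick_content_source(args: dict[str, Any]) -> str:
    """Return the name of the single content source that was provided.

    Raises ValueError when zero or more than one source is present.
    Hypothetical helper mirroring the rule in the tool description.
    """
    provided = [name for name in _CONTENT_SOURCES if args.get(name) is not None]
    if len(provided) != 1:
        raise ValueError(
            "provide exactly one of content, content_base64, source_path; "
            f"got {provided or 'none'}"
        )
    return provided[0]
```

Checking `is not None` rather than truthiness lets an empty string count as provided content, which matters if writing an empty file is meant to be allowed; that choice is an assumption of this sketch.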
@@ -859,10 +798,10 @@ class WriteWorkspaceFileTool(BaseTool):
return resolved
content: bytes = resolved
max_size = Config().max_file_size_mb * 1024 * 1024
max_size = _MAX_FILE_SIZE_MB * 1024 * 1024
if len(content) > max_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
message=f"File too large. Maximum size is {_MAX_FILE_SIZE_MB}MB",
session_id=session_id,
)
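The hunk above replaces repeated `Config().max_file_size_mb` lookups with the module-level `_MAX_FILE_SIZE_MB`. The size check itself is simple byte arithmetic, sketched below with a made-up limit. `exceeds_limit` and the value 10 are assumptions for illustration; the real limit comes from `Config()`.

```python
# Hypothetical value; in the diff this is read once from Config() at import.
_MAX_FILE_SIZE_MB = 10


def exceeds_limit(content: bytes, max_mb: int = _MAX_FILE_SIZE_MB) -> bool:
    """True when content is strictly larger than max_mb mebibytes."""
    return len(content) > max_mb * 1024 * 1024
```

Because the constant is evaluated when the module is imported, a configuration change made afterwards would not be picked up until the process restarts. That trade-off seems acceptable for a value that also appears in a static tool description.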
@@ -944,12 +883,7 @@ class DeleteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Delete a file from the user's persistent workspace (cloud storage). "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
return "Delete a file from persistent workspace. Specify file_id or path. Paths scoped to current session; use /sessions/<id>/... for cross-session access."
@property
def parameters(self) -> dict[str, Any]:
@@ -958,17 +892,14 @@ class DeleteWorkspaceFileTool(BaseTool):
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
"description": "File ID from list_workspace_files.",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
"description": "Virtual file path.",
},
},
"required": [], # At least one must be provided
"required": [], # At least one of file_id or path must be provided
}
@property

View File

@@ -76,7 +76,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,

View File

@@ -877,12 +877,12 @@ async def get_execution_outputs_by_node_exec_id(
where={"referencedByOutputExecId": node_exec_id}
)
result = {}
result: CompletedBlockOutput = defaultdict(list)
for output in outputs:
if output.data is not None:
result[output.name] = type_utils.convert(output.data, JsonValue)
result[output.name].append(type_utils.convert(output.data, JsonValue))
return result
return dict(result)
async def update_graph_execution_start_time(
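The fix above changes the return shape from one value per output name to a list of values per name. A database-free sketch of the same grouping follows. `collect_outputs` and the `(name, data)` tuples are hypothetical stand-ins for the Prisma rows used in the real function.

```python
from collections import defaultdict
from typing import Any


def collect_outputs(rows: list[tuple[str, Any]]) -> dict[str, list[Any]]:
    """Group output values by name, skipping rows whose data is None."""
    result: dict[str, list[Any]] = defaultdict(list)
    for name, data in rows:
        if data is not None:
            # Append instead of assigning, so repeated names keep every value.
            result[name].append(data)
    # Return a plain dict so callers do not get defaultdict's auto-insert behavior.
    return dict(result)
```

With plain assignment, a second row named `"result"` would silently overwrite the first and the value would not be a list. With `defaultdict(list)`, both values are kept, and a row with `None` data never touches the mapping, so no empty key is created.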

View File

@@ -0,0 +1,102 @@
"""Test that get_execution_outputs_by_node_exec_id returns CompletedBlockOutput.
CompletedBlockOutput is dict[str, list[Any]] — values must be lists.
The RPC service layer validates return types via TypeAdapter, so if
the function returns plain values instead of lists, it causes:
1 validation error for dict[str,list[any]] response
Input should be a valid list [type=list_type, input_value='', input_type=str]
This breaks SmartDecisionMakerBlock agent mode tool execution.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import TypeAdapter
from backend.data.block import CompletedBlockOutput
@pytest.mark.asyncio
async def test_outputs_are_lists():
"""Each value in the returned dict must be a list, matching CompletedBlockOutput."""
from backend.data.execution import get_execution_outputs_by_node_exec_id
mock_output = MagicMock()
mock_output.name = "response"
mock_output.data = "some text output"
with patch(
"backend.data.execution.AgentNodeExecutionInputOutput.prisma"
) as mock_prisma:
mock_prisma.return_value.find_many = AsyncMock(return_value=[mock_output])
result = await get_execution_outputs_by_node_exec_id("test-exec-id")
# The result must conform to CompletedBlockOutput = dict[str, list[Any]]
assert "response" in result
assert isinstance(
result["response"], list
), f"Expected list, got {type(result['response']).__name__}: {result['response']!r}"
# Must also pass TypeAdapter validation (this is what the RPC layer does)
adapter = TypeAdapter(CompletedBlockOutput)
validated = adapter.validate_python(result) # This validation is what failed in prod before the fix
assert validated == {"response": ["some text output"]}
@pytest.mark.asyncio
async def test_multiple_outputs_same_name_are_collected():
"""Multiple outputs with the same name should all appear in the list."""
from backend.data.execution import get_execution_outputs_by_node_exec_id
mock_out1 = MagicMock()
mock_out1.name = "result"
mock_out1.data = "first"
mock_out2 = MagicMock()
mock_out2.name = "result"
mock_out2.data = "second"
with patch(
"backend.data.execution.AgentNodeExecutionInputOutput.prisma"
) as mock_prisma:
mock_prisma.return_value.find_many = AsyncMock(
return_value=[mock_out1, mock_out2]
)
result = await get_execution_outputs_by_node_exec_id("test-exec-id")
assert isinstance(result["result"], list)
assert len(result["result"]) == 2
@pytest.mark.asyncio
async def test_empty_outputs_returns_empty_dict():
"""No outputs → empty dict."""
from backend.data.execution import get_execution_outputs_by_node_exec_id
with patch(
"backend.data.execution.AgentNodeExecutionInputOutput.prisma"
) as mock_prisma:
mock_prisma.return_value.find_many = AsyncMock(return_value=[])
result = await get_execution_outputs_by_node_exec_id("test-exec-id")
assert result == {}
@pytest.mark.asyncio
async def test_none_data_skipped():
"""Outputs with data=None should be skipped."""
from backend.data.execution import get_execution_outputs_by_node_exec_id
mock_output = MagicMock()
mock_output.name = "response"
mock_output.data = None
with patch(
"backend.data.execution.AgentNodeExecutionInputOutput.prisma"
) as mock_prisma:
mock_prisma.return_value.find_many = AsyncMock(return_value=[mock_output])
result = await get_execution_outputs_by_node_exec_id("test-exec-id")
assert result == {}
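The four tests above pin down one contract: outputs are grouped by name into lists, and rows whose data is `None` are dropped. A minimal sketch of that grouping, assuming each row is reduced to a `(name, data)` pair; `group_outputs` is a hypothetical helper for illustration and not the actual body of `get_execution_outputs_by_node_exec_id`:

```python
from collections import defaultdict
from typing import Any, Iterable, Tuple


def group_outputs(rows: Iterable[Tuple[str, Any]]) -> dict[str, list[Any]]:
    """Group (name, data) pairs into the dict[str, list[Any]] shape.

    Hypothetical helper: mirrors the behaviour the tests assert,
    not the real database-backed implementation.
    """
    grouped: defaultdict[str, list[Any]] = defaultdict(list)
    for name, data in rows:
        if data is None:
            # Rows without data are skipped, as in test_none_data_skipped.
            continue
        grouped[name].append(data)
    # Return a plain dict so an empty input yields {} rather than a defaultdict.
    return dict(grouped)
```

Wrapping every value in a list, even when only one output exists, is what lets the result validate against `dict[str, list[Any]]`.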


@@ -38,7 +38,7 @@ from backend.util.request import parse_url
from .block import BlockInput
from .db import BaseDbModel
from .db import prisma as db
-from .db import query_raw_with_schema, transaction
+from .db import execute_raw_with_schema, query_raw_with_schema, transaction
from .dynamic_fields import is_tool_pin, sanitize_pin_name
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
@@ -1669,16 +1669,15 @@ async def migrate_llm_models(migrate_to: LlmModel):
# Update each block
for id, path in llm_model_fields.items():
-query = f"""
-UPDATE platform."AgentNode"
+query = """
+UPDATE {schema_prefix}"AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
WHERE "agentBlockId" = $3
AND "constantInput" ? ($4)::text
-AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
-"""
+AND "constantInput"->>($4)::text NOT IN """ + escaped_enum_values
-await db.execute_raw(
-query, # type: ignore - is supposed to be LiteralString
+await execute_raw_with_schema(
+query,
[path],
migrate_to.value,
id,
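The hunk above replaces the hard-coded `platform.` schema with a `{schema_prefix}` placeholder. A minimal sketch of how such a placeholder might be filled, assuming plain `str.format` substitution; `apply_schema_prefix` is a hypothetical name, and the real `execute_raw_with_schema` may resolve the schema differently:

```python
def apply_schema_prefix(query_template: str, schema: str = "") -> str:
    """Fill the {schema_prefix} placeholder in a raw SQL template.

    Hypothetical helper for illustration only; it is not the
    project's execute_raw_with_schema.
    """
    # An empty schema means "use the connection's default search path".
    prefix = f"{schema}." if schema else ""
    return query_template.format(schema_prefix=prefix)


template = 'UPDATE {schema_prefix}"AgentNode" SET "constantInput" = $1'
qualified = apply_schema_prefix(template, "platform")
unqualified = apply_schema_prefix(template)
```

Under this assumption, any text concatenated onto the template before formatting must not contain literal braces, since `str.format` would treat them as placeholders.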


@@ -1,774 +0,0 @@
import asyncio
import csv
import io
import logging
import os
import re
import socket
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, Optional
from uuid import uuid4
import prisma.enums
import prisma.models
import prisma.types
from prisma.errors import UniqueViolationError
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
from backend.data.db import transaction
from backend.data.model import User
from backend.data.redis_client import get_redis_async
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
)
from backend.data.user import get_user_by_email, get_user_by_id
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.json import SafeJson
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
_WORKER_ID = f"{socket.gethostname()}:{os.getpid()}"
_tally_seed_tasks: dict[str, asyncio.Task] = {}
_TALLY_STALE_SECONDS = 300
_MAX_TALLY_ERROR_LENGTH = 200
_email_adapter = TypeAdapter(EmailStr)
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
MAX_BULK_INVITE_ROWS = 500
class InvitedUserRecord(BaseModel):
id: str
email: str
status: prisma.enums.InvitedUserStatus
auth_user_id: Optional[str] = None
name: Optional[str] = None
tally_understanding: Optional[dict[str, Any]] = None
tally_status: prisma.enums.TallyComputationStatus
tally_computed_at: Optional[datetime] = None
tally_error: Optional[str] = None
created_at: datetime
updated_at: datetime
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
created_at=invited_user.createdAt,
updated_at=invited_user.updatedAt,
)
class BulkInvitedUserRowResult(BaseModel):
row_number: int
email: Optional[str] = None
name: Optional[str] = None
status: Literal["CREATED", "SKIPPED", "ERROR"]
message: str
invited_user: Optional[InvitedUserRecord] = None
class BulkInvitedUsersResult(BaseModel):
created_count: int
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResult]
@dataclass(frozen=True)
class _ParsedInviteRow:
row_number: int
email: str
name: Optional[str]
def normalize_email(email: str) -> str:
return email.strip().lower()
def is_internal_email(email: str) -> bool:
"""Return True for @agpt.co addresses, which always bypass the invite gate."""
return normalize_email(email).endswith("@agpt.co")
def _normalize_name(name: Optional[str]) -> Optional[str]:
if name is None:
return None
normalized = name.strip()
return normalized or None
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
if preferred_name:
return preferred_name
local_part = email.split("@", 1)[0].strip()
return local_part or "user"
def _sanitize_username_base(email: str) -> str:
local_part = email.split("@", 1)[0].lower()
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
sanitized = sanitized.strip("-")
return sanitized[:40] or "user"
async def _generate_unique_profile_username(email: str, tx) -> str:
base = _sanitize_username_base(email)
for _ in range(2):
candidate = f"{base}-{uuid4().hex[:6]}"
existing = await prisma.models.Profile.prisma(tx).find_unique(
where={"username": candidate}
)
if existing is None:
return candidate
raise RuntimeError(f"Unable to generate unique username for {email}")
async def _ensure_default_profile(
user_id: str,
email: str,
preferred_name: Optional[str],
tx,
) -> None:
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
where={"userId": user_id}
)
if existing_profile is not None:
return
username = await _generate_unique_profile_username(email, tx)
await prisma.models.Profile.prisma(tx).create(
data=prisma.types.ProfileCreateInput(
userId=user_id,
name=_default_profile_name(email, preferred_name),
username=username,
description="I'm new here",
links=[],
avatarUrl="",
)
)
async def _ensure_default_onboarding(user_id: str, tx) -> None:
await prisma.models.UserOnboarding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
"update": {},
},
)
async def _apply_tally_understanding(
user_id: str,
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
if not isinstance(invited_user.tallyUnderstanding, dict):
return
try:
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
except Exception:
logger.warning(
"Malformed tallyUnderstanding for invited user %s; skipping",
invited_user.id,
exc_info=True,
)
return
payload = merge_business_understanding_data({}, input_data)
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(payload)},
"update": {"data": SafeJson(payload)},
},
)
async def check_invite_eligibility(email: str) -> bool:
"""Check if an email is allowed to sign up based on the invite list.
Args:
email: The email to check (will be normalized internally).
Returns True if the email has an active (INVITED) invite record.
Does NOT check enable_invite_gate — the caller is responsible for that.
"""
email = normalize_email(email)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": email}
)
return (
invited_user is not None
and invited_user.status == prisma.enums.InvitedUserStatus.INVITED
)
async def list_invited_users(
page: int = 1,
page_size: int = 50,
) -> tuple[list[InvitedUserRecord], int]:
total = await prisma.models.InvitedUser.prisma().count()
invited_users = await prisma.models.InvitedUser.prisma().find_many(
order={"createdAt": "desc"},
skip=(page - 1) * page_size,
take=page_size,
)
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:
normalized_email = normalize_email(email)
normalized_name = _normalize_name(name)
existing_user = await prisma.models.User.prisma().find_unique(
where={"email": normalized_email}
)
if existing_user is not None:
raise PreconditionFailed("An active user with this email already exists")
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if existing_invited_user is not None:
raise PreconditionFailed("An invited user with this email already exists")
try:
invited_user = await prisma.models.InvitedUser.prisma().create(
data={
"email": normalized_email,
"name": normalized_name,
"status": prisma.enums.InvitedUserStatus.INVITED,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
}
)
except UniqueViolationError:
raise PreconditionFailed("An invited user with this email already exists")
schedule_invited_user_tally_precompute(invited_user.id)
return InvitedUserRecord.from_db(invited_user)
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
raise PreconditionFailed("Claimed invited users cannot be revoked")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return InvitedUserRecord.from_db(invited_user)
revoked_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
)
if revoked_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
return InvitedUserRecord.from_db(revoked_user)
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
refreshed_user = await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": None,
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
"tallyComputedAt": None,
"tallyError": None,
},
)
if refreshed_user is None:
raise NotFoundError(f"Invited user {invited_user_id} not found")
schedule_invited_user_tally_precompute(invited_user_id)
return InvitedUserRecord.from_db(refreshed_user)
def _decode_bulk_invite_file(content: bytes) -> str:
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
raise ValueError("Invite file exceeds the maximum size of 1 MB")
try:
return content.decode("utf-8-sig")
except UnicodeDecodeError as exc:
raise ValueError("Invite file must be UTF-8 encoded") from exc
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
indexed_rows: list[tuple[int, list[str]]] = []
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
normalized_row = [cell.strip() for cell in row]
if any(normalized_row):
indexed_rows.append((row_number, normalized_row))
if not indexed_rows:
return []
header = [cell.lower() for cell in indexed_rows[0][1]]
has_header = "email" in header
email_index = header.index("email") if has_header else 0
name_index: Optional[int] = (
header.index("name")
if has_header and "name" in header
else (1 if not has_header else None)
)
data_rows = indexed_rows[1:] if has_header else indexed_rows
parsed_rows: list[_ParsedInviteRow] = []
for row_number, row in data_rows:
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
break
email = row[email_index].strip() if len(row) > email_index else ""
name = (
row[name_index].strip()
if name_index is not None and len(row) > name_index
else ""
)
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=email,
name=name or None,
)
)
return parsed_rows
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
parsed_rows: list[_ParsedInviteRow] = []
for row_number, raw_line in enumerate(text.splitlines(), start=1):
if len(parsed_rows) >= MAX_BULK_INVITE_ROWS:
break
line = raw_line.strip()
if not line or line.startswith("#"):
continue
parsed_rows.append(
_ParsedInviteRow(
row_number=row_number,
email=line,
name=None,
)
)
return parsed_rows
def _parse_bulk_invite_file(
filename: Optional[str],
content: bytes,
) -> list[_ParsedInviteRow]:
text = _decode_bulk_invite_file(content)
file_name = filename.lower() if filename else ""
parsed_rows = (
_parse_bulk_invite_csv(text)
if file_name.endswith(".csv")
else _parse_bulk_invite_text(text)
)
if not parsed_rows:
raise ValueError("Invite file did not contain any emails")
return parsed_rows
async def bulk_create_invited_users_from_file(
filename: Optional[str],
content: bytes,
) -> BulkInvitedUsersResult:
parsed_rows = _parse_bulk_invite_file(filename, content)
created_count = 0
skipped_count = 0
error_count = 0
results: list[BulkInvitedUserRowResult] = []
seen_emails: set[str] = set()
for row in parsed_rows:
row_name = _normalize_name(row.name)
try:
validated_email = _email_adapter.validate_python(row.email)
except ValidationError:
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=row.email or None,
name=row_name,
status="ERROR",
message="Invalid email address",
)
)
continue
normalized_email = normalize_email(str(validated_email))
if normalized_email in seen_emails:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message="Duplicate email in upload file",
)
)
continue
seen_emails.add(normalized_email)
try:
invited_user = await create_invited_user(normalized_email, row_name)
except PreconditionFailed as exc:
skipped_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="SKIPPED",
message=str(exc),
)
)
except Exception:
masked = mask_email(normalized_email)
logger.exception(
"Failed to create bulk invite for row %s (%s)",
row.row_number,
masked,
)
error_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="ERROR",
message="Unexpected error creating invite",
)
)
else:
created_count += 1
results.append(
BulkInvitedUserRowResult(
row_number=row.row_number,
email=normalized_email,
name=row_name,
status="CREATED",
message="Invite created",
invited_user=invited_user,
)
)
return BulkInvitedUsersResult(
created_count=created_count,
skipped_count=skipped_count,
error_count=error_count,
results=results,
)
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"id": invited_user_id}
)
if invited_user is None:
return
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
return
try:
r = await get_redis_async()
except Exception:
r = None
lock: AsyncClusterLock | None = None
if r is not None:
lock = AsyncClusterLock(
redis=r,
key=f"tally_seed:{invited_user_id}",
owner_id=_WORKER_ID,
timeout=_TALLY_STALE_SECONDS,
)
current_owner = await lock.try_acquire()
if current_owner is None:
logger.warning("Redis unavailable for tally lock - skipping tally enrichment")
return
elif current_owner != _WORKER_ID:
logger.debug(
"Tally seed for %s already locked by %s, skipping",
invited_user_id,
current_owner,
)
return
if (
invited_user.tallyStatus == prisma.enums.TallyComputationStatus.RUNNING
and invited_user.updatedAt is not None
):
age = (datetime.now(timezone.utc) - invited_user.updatedAt).total_seconds()
if age < _TALLY_STALE_SECONDS:
logger.debug(
"Tally task for %s still RUNNING (age=%ds), skipping",
invited_user_id,
int(age),
)
return
logger.info(
"Tally task for %s is stale (age=%ds), re-running",
invited_user_id,
int(age),
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
"tallyError": None,
},
)
try:
input_data = await get_business_understanding_input_from_tally(
invited_user.email,
require_api_key=True,
)
payload = (
SafeJson(input_data.model_dump(exclude_none=True))
if input_data is not None
else None
)
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyUnderstanding": payload,
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
"tallyComputedAt": datetime.now(timezone.utc),
"tallyError": None,
},
)
except Exception as exc:
logger.exception(
"Failed to compute Tally understanding for invited user %s",
invited_user_id,
)
sanitized_error = re.sub(
r"https?://\S+", "<url>", f"{type(exc).__name__}: {exc}"
)[:_MAX_TALLY_ERROR_LENGTH]
await prisma.models.InvitedUser.prisma().update(
where={"id": invited_user_id},
data={
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
"tallyError": sanitized_error,
},
)
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
existing = _tally_seed_tasks.get(invited_user_id)
if existing is not None and not existing.done():
logger.debug("Tally task already running for %s, skipping", invited_user_id)
return
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
_tally_seed_tasks[invited_user_id] = task
def _on_done(t: asyncio.Task, _id: str = invited_user_id) -> None:
if _tally_seed_tasks.get(_id) is t:
del _tally_seed_tasks[_id]
task.add_done_callback(_on_done)
async def _open_signup_create_user(
auth_user_id: str,
normalized_email: str,
metadata_name: Optional[str],
) -> User:
"""Create a user without requiring an invite (open signup mode)."""
preferred_name = _normalize_name(metadata_name)
try:
async with transaction() as tx:
user = await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await _ensure_default_profile(
auth_user_id, normalized_email, preferred_name, tx
)
await _ensure_default_onboarding(auth_user_id, tx)
except UniqueViolationError:
existing = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if existing is not None:
return User.from_db(existing)
raise
return User.from_db(user)
# TODO: We need to change this function's logic before going live
async def get_or_activate_user(user_data: dict) -> User:
auth_user_id = user_data.get("sub")
if not auth_user_id:
raise NotAuthorizedError("User ID not found in token")
auth_email = user_data.get("email")
if not auth_email:
raise NotAuthorizedError("Email not found in token")
normalized_email = normalize_email(auth_email)
user_metadata = user_data.get("user_metadata")
metadata_name = (
user_metadata.get("name") if isinstance(user_metadata, dict) else None
)
existing_user = None
try:
existing_user = await get_user_by_id(auth_user_id)
except ValueError:
existing_user = None
except Exception:
logger.exception("Error on get user by id during tally enrichment process")
raise
if existing_user is not None:
return existing_user
if not _settings.config.enable_invite_gate or is_internal_email(normalized_email):
return await _open_signup_create_user(
auth_user_id, normalized_email, metadata_name
)
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
where={"email": normalized_email}
)
if invited_user is None:
raise NotAuthorizedError("Your email is not allowed to access the platform")
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
try:
async with transaction() as tx:
current_user = await prisma.models.User.prisma(tx).find_unique(
where={"id": auth_user_id}
)
if current_user is not None:
return User.from_db(current_user)
current_invited_user = await prisma.models.InvitedUser.prisma(
tx
).find_unique(where={"email": normalized_email})
if current_invited_user is None:
raise NotAuthorizedError(
"Your email is not allowed to access the platform"
)
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
raise NotAuthorizedError("Your invitation is no longer active")
if current_invited_user.authUserId not in (None, auth_user_id):
raise NotAuthorizedError("Your invitation has already been claimed")
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
await prisma.models.User.prisma(tx).create(
data=prisma.types.UserCreateInput(
id=auth_user_id,
email=normalized_email,
name=preferred_name,
)
)
await prisma.models.InvitedUser.prisma(tx).update(
where={"id": current_invited_user.id},
data={
"status": prisma.enums.InvitedUserStatus.CLAIMED,
"authUserId": auth_user_id,
},
)
await _ensure_default_profile(
auth_user_id,
normalized_email,
preferred_name,
tx,
)
await _ensure_default_onboarding(auth_user_id, tx)
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
except UniqueViolationError:
logger.info("Concurrent activation for user %s; re-fetching", auth_user_id)
already_created = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if already_created is not None:
return User.from_db(already_created)
raise RuntimeError(
f"UniqueViolationError during activation but user {auth_user_id} not found"
)
get_user_by_id.cache_delete(auth_user_id)
get_user_by_email.cache_delete(normalized_email)
activated_user = await prisma.models.User.prisma().find_unique(
where={"id": auth_user_id}
)
if activated_user is None:
raise RuntimeError(
f"Activated user {auth_user_id} was not found after creation"
)
return User.from_db(activated_user)
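For reference, the username derivation in the removed module can be exercised in isolation. The sketch below copies the logic of `_sanitize_username_base` into a standalone function (renamed without the leading underscore); the sample addresses are made up for illustration:

```python
import re


def sanitize_username_base(email: str) -> str:
    """Derive a username base from the local part of an email address."""
    # Take everything before the first "@" and lowercase it.
    local_part = email.split("@", 1)[0].lower()
    # Keep only lowercase letters, digits, and hyphens.
    sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
    # Drop leading and trailing hyphens left over after filtering.
    sanitized = sanitized.strip("-")
    # Cap the length and fall back to "user" when nothing survives.
    return sanitized[:40] or "user"


examples = {
    email: sanitize_username_base(email)
    for email in ("Jane.Doe+test@example.com", "-a-b-@example.com", "--@example.com")
}
```

The original code then appends a six-character hex suffix from `uuid4()` to this base, checking the candidate against existing profiles in up to two attempts before raising.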


@@ -1,409 +0,0 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import cast
from unittest.mock import AsyncMock, Mock
import prisma.enums
import prisma.models
import pytest
import pytest_mock
from backend.util.exceptions import NotAuthorizedError, PreconditionFailed
from .invited_user import (
InvitedUserRecord,
bulk_create_invited_users_from_file,
check_invite_eligibility,
create_invited_user,
get_or_activate_user,
retry_invited_user_tally,
)
def _invited_user_db_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="invite-1",
email="invited@example.com",
status=status,
authUserId=None,
name="Invited User",
tallyUnderstanding=tally_understanding,
tallyStatus=prisma.enums.TallyComputationStatus.PENDING,
tallyComputedAt=None,
tallyError=None,
createdAt=now,
updatedAt=now,
)
def _invited_user_record(
*,
status: prisma.enums.InvitedUserStatus = prisma.enums.InvitedUserStatus.INVITED,
tally_understanding: dict | None = None,
):
return InvitedUserRecord.from_db(
cast(
prisma.models.InvitedUser,
_invited_user_db_record(
status=status,
tally_understanding=tally_understanding,
),
)
)
def _user_db_record():
now = datetime.now(timezone.utc)
return SimpleNamespace(
id="auth-user-1",
email="invited@example.com",
emailVerified=True,
name="Invited User",
createdAt=now,
updatedAt=now,
metadata={},
integrations="",
stripeCustomerId=None,
topUpConfig=None,
maxEmailsPerDay=3,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="not-set",
)
@pytest.mark.asyncio
async def test_create_invited_user_rejects_existing_active_user(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=_user_db_record())
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock()
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(PreconditionFailed):
await create_invited_user("Invited@example.com")
@pytest.mark.asyncio
async def test_create_invited_user_schedules_tally_seed(
mocker: pytest_mock.MockerFixture,
) -> None:
user_repo = Mock()
user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
invited_user_repo.create = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", return_value=user_repo
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await create_invited_user("Invited@example.com", "Invited User")
assert invited_user.email == "invited@example.com"
invited_user_repo.create.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_retry_invited_user_tally_resets_state_and_schedules(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=_invited_user_db_record())
invited_user_repo.update = AsyncMock(return_value=_invited_user_db_record())
schedule = mocker.patch(
"backend.data.invited_user.schedule_invited_user_tally_precompute"
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
invited_user = await retry_invited_user_tally("invite-1")
assert invited_user.id == "invite-1"
invited_user_repo.update.assert_awaited_once()
schedule.assert_called_once_with("invite-1")
@pytest.mark.asyncio
async def test_get_or_activate_user_requires_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
invited_user_repo = Mock()
invited_user_repo.find_unique = AsyncMock(return_value=None)
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
mock_get_user_by_id.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_id",
mock_get_user_by_id,
)
mocker.patch(
"backend.data.invited_user._settings.config.enable_invite_gate",
True,
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
return_value=invited_user_repo,
)
with pytest.raises(NotAuthorizedError):
await get_or_activate_user(
{"sub": "auth-user-1", "email": "invited@example.com"}
)
@pytest.mark.asyncio
async def test_get_or_activate_user_creates_user_from_invite(
mocker: pytest_mock.MockerFixture,
) -> None:
tx = object()
invited_user = _invited_user_db_record(
tally_understanding={"user_name": "Invited User", "industry": "Automation"}
)
created_user = _user_db_record()
outside_user_repo = Mock()
# Only called once at post-transaction verification (line 741);
# get_user_by_id (line 657) uses prisma.user.find_unique, not this mock.
outside_user_repo.find_unique = AsyncMock(return_value=created_user)
inside_user_repo = Mock()
inside_user_repo.find_unique = AsyncMock(return_value=None)
inside_user_repo.create = AsyncMock(return_value=created_user)
outside_invited_repo = Mock()
outside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo = Mock()
inside_invited_repo.find_unique = AsyncMock(return_value=invited_user)
inside_invited_repo.update = AsyncMock(return_value=invited_user)
def user_prisma(client=None):
return inside_user_repo if client is tx else outside_user_repo
def invited_user_prisma(client=None):
return inside_invited_repo if client is tx else outside_invited_repo
@asynccontextmanager
async def fake_transaction():
yield tx
# Mock get_user_by_id since it uses prisma.user.find_unique (global client),
# not prisma.models.User.prisma().find_unique which we mock above.
mock_get_user_by_id = AsyncMock(side_effect=ValueError("User not found"))
mock_get_user_by_id.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_id",
mock_get_user_by_id,
)
mock_get_user_by_email = AsyncMock()
mock_get_user_by_email.cache_delete = Mock()
mocker.patch(
"backend.data.invited_user.get_user_by_email",
mock_get_user_by_email,
)
ensure_profile = mocker.patch(
"backend.data.invited_user._ensure_default_profile",
AsyncMock(),
)
ensure_onboarding = mocker.patch(
"backend.data.invited_user._ensure_default_onboarding",
AsyncMock(),
)
apply_tally = mocker.patch(
"backend.data.invited_user._apply_tally_understanding",
AsyncMock(),
)
mocker.patch(
"backend.data.invited_user._settings.config.enable_invite_gate",
True,
)
mocker.patch("backend.data.invited_user.transaction", fake_transaction)
mocker.patch(
"backend.data.invited_user.prisma.models.User.prisma", side_effect=user_prisma
)
mocker.patch(
"backend.data.invited_user.prisma.models.InvitedUser.prisma",
side_effect=invited_user_prisma,
)
user = await get_or_activate_user(
{
"sub": "auth-user-1",
"email": "Invited@example.com",
"user_metadata": {"name": "Invited User"},
}
)
assert user.id == "auth-user-1"
inside_user_repo.create.assert_awaited_once()
inside_invited_repo.update.assert_awaited_once()
ensure_profile.assert_awaited_once()
ensure_onboarding.assert_awaited_once_with("auth-user-1", tx)
apply_tally.assert_awaited_once_with("auth-user-1", invited_user, tx)
@pytest.mark.asyncio
async def test_bulk_create_invited_users_from_text_file(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_record(),
_invited_user_record(),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.txt",
b"Invited@example.com\nsecond@example.com\n",
)
assert result.created_count == 2
assert result.skipped_count == 0
assert result.error_count == 0
assert [row.status for row in result.results] == ["CREATED", "CREATED"]
assert create_invited.await_count == 2
@pytest.mark.asyncio
async def test_bulk_create_invited_users_handles_csv_duplicates_and_invalid_rows(
mocker: pytest_mock.MockerFixture,
) -> None:
create_invited = mocker.patch(
"backend.data.invited_user.create_invited_user",
AsyncMock(
side_effect=[
_invited_user_record(),
PreconditionFailed("An invited user with this email already exists"),
]
),
)
result = await bulk_create_invited_users_from_file(
"invites.csv",
(
"email,name\n"
"valid@example.com,Valid User\n"
"not-an-email,Bad Row\n"
"valid@example.com,Duplicate In File\n"
"existing@example.com,Existing User\n"
).encode("utf-8"),
)
assert result.created_count == 1
assert result.skipped_count == 2
assert result.error_count == 1
assert [row.status for row in result.results] == [
"CREATED",
"ERROR",
"SKIPPED",
"SKIPPED",
]
assert create_invited.await_count == 2


# ---------------------------------------------------------------------------
# check_invite_eligibility tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_invite_eligibility_returns_true_for_invited(
    mocker: pytest_mock.MockerFixture,
) -> None:
    invited = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.INVITED)
    repo = Mock()
    repo.find_unique = AsyncMock(return_value=invited)
    mocker.patch(
        "backend.data.invited_user.prisma.models.InvitedUser.prisma",
        return_value=repo,
    )

    result = await check_invite_eligibility("invited@example.com")

    assert result is True
    repo.find_unique.assert_awaited_once()


@pytest.mark.asyncio
async def test_check_invite_eligibility_returns_false_for_no_record(
    mocker: pytest_mock.MockerFixture,
) -> None:
    repo = Mock()
    repo.find_unique = AsyncMock(return_value=None)
    mocker.patch(
        "backend.data.invited_user.prisma.models.InvitedUser.prisma",
        return_value=repo,
    )

    result = await check_invite_eligibility("unknown@example.com")

    assert result is False


@pytest.mark.asyncio
async def test_check_invite_eligibility_returns_false_for_claimed(
    mocker: pytest_mock.MockerFixture,
) -> None:
    claimed = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.CLAIMED)
    repo = Mock()
    repo.find_unique = AsyncMock(return_value=claimed)
    mocker.patch(
        "backend.data.invited_user.prisma.models.InvitedUser.prisma",
        return_value=repo,
    )

    result = await check_invite_eligibility("claimed@example.com")

    assert result is False


@pytest.mark.asyncio
async def test_check_invite_eligibility_returns_false_for_revoked(
    mocker: pytest_mock.MockerFixture,
) -> None:
    revoked = _invited_user_db_record(status=prisma.enums.InvitedUserStatus.REVOKED)
    repo = Mock()
    repo.find_unique = AsyncMock(return_value=revoked)
    mocker.patch(
        "backend.data.invited_user.prisma.models.InvitedUser.prisma",
        return_value=repo,
    )

    result = await check_invite_eligibility("revoked@example.com")

    assert result is False


@@ -0,0 +1,40 @@
"""LLM Registry - Dynamic model management system."""

from backend.blocks.llm import ModelMetadata

from .notifications import (
    publish_registry_refresh_notification,
    subscribe_to_registry_refresh,
)
from .registry import (
    RegistryModel,
    RegistryModelCost,
    RegistryModelCreator,
    clear_registry_cache,
    get_all_model_slugs_for_validation,
    get_all_models,
    get_default_model_slug,
    get_enabled_models,
    get_model,
    get_schema_options,
    refresh_llm_registry,
)

__all__ = [
    # Models
    "ModelMetadata",
    "RegistryModel",
    "RegistryModelCost",
    "RegistryModelCreator",
    # Cache management
    "clear_registry_cache",
    "publish_registry_refresh_notification",
    "subscribe_to_registry_refresh",
    # Read functions
    "refresh_llm_registry",
    "get_model",
    "get_all_models",
    "get_enabled_models",
    "get_schema_options",
    "get_default_model_slug",
    "get_all_model_slugs_for_validation",
]


@@ -0,0 +1,84 @@
"""Pub/sub notifications for LLM registry cross-process synchronisation."""

from __future__ import annotations

import asyncio
import logging
from typing import Awaitable, Callable

logger = logging.getLogger(__name__)

REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"


async def publish_registry_refresh_notification() -> None:
    """Publish a refresh signal so all other workers reload their in-process cache."""
    from backend.data.redis_client import get_redis_async

    try:
        redis = await get_redis_async()
        await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
        logger.debug("Published LLM registry refresh notification")
    except Exception as e:
        logger.warning("Failed to publish registry refresh notification: %s", e)


async def subscribe_to_registry_refresh(
    on_refresh: Callable[[], Awaitable[None]],
) -> None:
    """Listen for registry refresh signals and call on_refresh each time one arrives.

    Designed to run as a long-lived background asyncio.Task. Automatically
    reconnects if the Redis connection drops.

    Args:
        on_refresh: Async callable invoked on each refresh signal.
            Typically ``llm_registry.refresh_llm_registry``.
    """
    from backend.data.redis_client import HOST, PASSWORD, PORT
    from redis.asyncio import Redis as AsyncRedis

    while True:
        try:
            # Dedicated connection: pub/sub must not share a connection used
            # for regular commands.
            redis_sub = AsyncRedis(
                host=HOST, port=PORT, password=PASSWORD, decode_responses=True
            )
            pubsub = redis_sub.pubsub()
            await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
            logger.info("Subscribed to LLM registry refresh channel")

            while True:
                try:
                    message = await pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=1.0
                    )
                    if (
                        message
                        and message["type"] == "message"
                        and message["channel"] == REGISTRY_REFRESH_CHANNEL
                    ):
                        logger.debug("LLM registry refresh signal received")
                        try:
                            await on_refresh()
                        except Exception as e:
                            logger.error(
                                "Error in registry on_refresh callback: %s", e
                            )
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    logger.warning(
                        "Error processing registry refresh message: %s", e
                    )
                    await asyncio.sleep(1)

        except asyncio.CancelledError:
            logger.info("LLM registry subscription task cancelled")
            break
        except Exception as e:
            logger.warning(
                "LLM registry subscription error: %s. Retrying in 5s...", e
            )
            await asyncio.sleep(5)
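
The docstring above says the subscriber is meant to run as a long-lived background `asyncio.Task`. The diff does not show where that task is started, so the following is only a minimal sketch of the start/cancel lifecycle. `fake_subscribe` and the `asyncio.Queue` are hypothetical stand-ins for `subscribe_to_registry_refresh` and the Redis channel, so the example runs without Redis.

```python
from __future__ import annotations

import asyncio
from typing import Awaitable, Callable


async def fake_subscribe(
    signals: asyncio.Queue[str],
    on_refresh: Callable[[], Awaitable[None]],
) -> None:
    """Hypothetical stand-in for the subscriber: one callback per signal."""
    while True:
        await signals.get()
        try:
            await on_refresh()
        except Exception as exc:  # keep listening even if the callback fails
            print(f"on_refresh failed: {exc}")


async def main() -> int:
    refresh_count = 0

    async def on_refresh() -> None:
        nonlocal refresh_count
        refresh_count += 1

    signals: asyncio.Queue[str] = asyncio.Queue()

    # Startup: launch the listener as a background task.
    task = asyncio.create_task(fake_subscribe(signals, on_refresh))

    # Two "refresh" signals arrive while the application is running.
    await signals.put("refresh")
    await signals.put("refresh")
    await asyncio.sleep(0.05)

    # Shutdown: cancel the task and wait for it to finish.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass  # the stand-in re-raises; the real subscriber logs and returns
    return refresh_count


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Holding on to the task object and awaiting it after `cancel()` lets shutdown wait until the listener has actually stopped.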


@@ -0,0 +1,254 @@
"""Core LLM registry implementation for managing models dynamically."""

from __future__ import annotations

import asyncio
import logging
from typing import Any

import prisma.models
from pydantic import BaseModel, ConfigDict

from backend.blocks.llm import ModelMetadata
from backend.util.cache import cached

logger = logging.getLogger(__name__)


class RegistryModelCost(BaseModel):
    """Cost configuration for an LLM model."""

    model_config = ConfigDict(frozen=True)

    unit: str  # "RUN" or "TOKENS"
    credit_cost: int
    credential_provider: str
    credential_id: str | None = None
    credential_type: str | None = None
    currency: str | None = None
    metadata: dict[str, Any] = {}


class RegistryModelCreator(BaseModel):
    """Creator information for an LLM model."""

    model_config = ConfigDict(frozen=True)

    id: str
    name: str
    display_name: str
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None


class RegistryModel(BaseModel):
    """Represents a model in the LLM registry."""

    model_config = ConfigDict(frozen=True)

    slug: str
    display_name: str
    description: str | None = None
    metadata: ModelMetadata
    capabilities: dict[str, Any] = {}
    extra_metadata: dict[str, Any] = {}
    provider_display_name: str
    is_enabled: bool
    is_recommended: bool = False
    costs: tuple[RegistryModelCost, ...] = ()
    creator: RegistryModelCreator | None = None

    # Typed capability fields from DB schema
    supports_tools: bool = False
    supports_json_output: bool = False
    supports_reasoning: bool = False
    supports_parallel_tool_calls: bool = False


# L1 in-process cache; Redis is the shared L2 via @cached(shared_cache=True)
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()


def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:  # type: ignore[name-defined]
    """Transform a raw Prisma LlmModel record into a RegistryModel instance."""
    costs = tuple(
        RegistryModelCost(
            unit=str(cost.unit),
            credit_cost=cost.creditCost,
            credential_provider=cost.credentialProvider,
            credential_id=cost.credentialId,
            credential_type=cost.credentialType,
            currency=cost.currency,
            metadata=dict(cost.metadata or {}),
        )
        for cost in (record.Costs or [])
    )

    creator = None
    if record.Creator:
        creator = RegistryModelCreator(
            id=record.Creator.id,
            name=record.Creator.name,
            display_name=record.Creator.displayName,
            description=record.Creator.description,
            website_url=record.Creator.websiteUrl,
            logo_url=record.Creator.logoUrl,
        )

    capabilities = dict(record.capabilities or {})

    if not record.Provider:
        logger.warning(
            "LlmModel %s has no Provider despite NOT NULL FK - "
            "falling back to providerId %s",
            record.slug,
            record.providerId,
        )
    provider_name = record.Provider.name if record.Provider else record.providerId
    provider_display = (
        record.Provider.displayName if record.Provider else record.providerId
    )
    creator_name = record.Creator.displayName if record.Creator else "Unknown"

    if record.priceTier not in (1, 2, 3):
        logger.warning(
            "LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
            record.slug,
            record.priceTier,
        )
    price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1

    metadata = ModelMetadata(
        provider=provider_name,
        context_window=record.contextWindow,
        max_output_tokens=(
            record.maxOutputTokens
            if record.maxOutputTokens is not None
            else record.contextWindow
        ),
        display_name=record.displayName,
        provider_name=provider_display,
        creator_name=creator_name,
        price_tier=price_tier,
    )

    return RegistryModel(
        slug=record.slug,
        display_name=record.displayName,
        description=record.description,
        metadata=metadata,
        capabilities=capabilities,
        extra_metadata=dict(record.metadata or {}),
        provider_display_name=provider_display,
        is_enabled=record.isEnabled,
        is_recommended=record.isRecommended,
        costs=costs,
        creator=creator,
        supports_tools=record.supportsTools,
        supports_json_output=record.supportsJsonOutput,
        supports_reasoning=record.supportsReasoning,
        supports_parallel_tool_calls=record.supportsParallelToolCalls,
    )


@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
async def _fetch_registry_from_db() -> list[RegistryModel]:
    """Fetch all LLM models from the database.

    Results are cached in Redis (shared_cache=True) so subsequent calls within
    the TTL window skip the DB entirely, both within this process and across
    all other workers that share the same Redis instance.
    """
    records = await prisma.models.LlmModel.prisma().find_many(  # type: ignore[attr-defined]
        include={"Provider": True, "Costs": True, "Creator": True}
    )
    logger.info("Fetched %d LLM models from database", len(records))
    return [_record_to_registry_model(r) for r in records]


def clear_registry_cache() -> None:
    """Invalidate the shared Redis cache for the registry DB fetch.

    Call this before refresh_llm_registry() after any admin DB mutation so the
    next fetch hits the database rather than serving the now-stale cached data.
    """
    _fetch_registry_from_db.cache_clear()


async def refresh_llm_registry() -> None:
    """Refresh the in-process L1 cache from Redis/DB.

    On the first call (or after clear_registry_cache()), fetches fresh data
    from the database and stores it in Redis. Subsequent calls by other
    workers hit the Redis cache instead of the DB.
    """
    async with _lock:
        try:
            models = await _fetch_registry_from_db()
            new_models = {m.slug: m for m in models}

            global _dynamic_models, _schema_options
            _dynamic_models = new_models
            _schema_options = _build_schema_options()

            logger.info(
                "LLM registry refreshed: %d models, %d schema options",
                len(_dynamic_models),
                len(_schema_options),
            )
        except Exception as e:
            logger.error("Failed to refresh LLM registry: %s", e, exc_info=True)
            raise


def _build_schema_options() -> list[dict[str, str]]:
    """Build schema options for model selection dropdown. Only includes enabled models."""
    return [
        {
            "label": model.display_name,
            "value": model.slug,
            "group": model.metadata.provider,
            "description": model.description or "",
        }
        for model in sorted(
            _dynamic_models.values(), key=lambda m: m.display_name.lower()
        )
        if model.is_enabled
    ]


def get_model(slug: str) -> RegistryModel | None:
    """Get a model by slug from the registry."""
    return _dynamic_models.get(slug)


def get_all_models() -> list[RegistryModel]:
    """Get all models from the registry (including disabled)."""
    return list(_dynamic_models.values())


def get_enabled_models() -> list[RegistryModel]:
    """Get only enabled models from the registry."""
    return [model for model in _dynamic_models.values() if model.is_enabled]


def get_schema_options() -> list[dict[str, str]]:
    """Get schema options for model selection dropdown (enabled models only)."""
    return list(_schema_options)


def get_default_model_slug() -> str | None:
    """Get the default model slug (first recommended, or first enabled)."""
    models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
    recommended = next(
        (m.slug for m in models if m.is_recommended and m.is_enabled), None
    )
    return recommended or next((m.slug for m in models if m.is_enabled), None)


def get_all_model_slugs_for_validation() -> list[str]:
    """Get all model slugs for validation (enabled models only)."""
    return [model.slug for model in _dynamic_models.values() if model.is_enabled]
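
The docstrings of `clear_registry_cache()` and `refresh_llm_registry()` imply an ordering after an admin mutation: invalidate the shared cache, reload this process, then notify the other workers. The sketch below illustrates why that ordering matters. It is not the project's code: `FakeBackend` and `Worker` are hypothetical in-memory stand-ins for the database, the Redis cache, the pub/sub channel and a worker process, and `refresh_runtime_caches` here borrows only the name used in the commit message.

```python
from __future__ import annotations

import asyncio


class FakeBackend:
    """Hypothetical stand-in for the database, the shared cache and pub/sub."""

    def __init__(self) -> None:
        self.database: dict[str, bool] = {"openai/gpt-4o": True}  # slug -> enabled
        self.shared_cache: dict[str, bool] | None = None  # plays the Redis L2 cache
        self.published: list[str] = []  # plays the pub/sub channel

    async def fetch_registry(self) -> dict[str, bool]:
        """Serve from the shared cache when it is filled, else read the database."""
        if self.shared_cache is None:
            self.shared_cache = dict(self.database)
        return dict(self.shared_cache)


class Worker:
    """One process with its own in-memory (L1) copy of the registry."""

    def __init__(self, backend: FakeBackend) -> None:
        self.backend = backend
        self.local: dict[str, bool] = {}

    async def refresh(self) -> None:
        self.local = await self.backend.fetch_registry()


async def refresh_runtime_caches(worker: Worker) -> None:
    backend = worker.backend
    backend.shared_cache = None  # 1. invalidate the shared cache
    await worker.refresh()  # 2. reload this process; this refills the shared cache
    backend.published.append("refresh")  # 3. tell the other workers to reload


async def main() -> tuple[bool, bool, list[str]]:
    backend = FakeBackend()
    admin, other = Worker(backend), Worker(backend)
    await admin.refresh()
    await other.refresh()

    backend.database["openai/gpt-4o"] = False  # admin mutation in the database

    await other.refresh()  # no invalidation yet: served from the stale shared cache
    stale = other.local["openai/gpt-4o"]

    await refresh_runtime_caches(admin)
    await other.refresh()  # what a subscriber callback would do on "refresh"
    fresh = other.local["openai/gpt-4o"]
    return stale, fresh, backend.published
```

In this sketch, a refresh without the invalidation step keeps returning the pre-mutation value. Once the shared cache has been cleared and refilled, the other worker picks up the change without touching the database itself.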


@@ -0,0 +1,358 @@
"""Unit tests for the LLM registry module."""

from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, Mock, patch

import pytest
import pydantic

from backend.data.llm_registry.registry import (
    RegistryModel,
    RegistryModelCost,
    RegistryModelCreator,
    _build_schema_options,
    _record_to_registry_model,
    get_default_model_slug,
    get_schema_options,
    refresh_llm_registry,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_mock_record(**overrides):
    """Build a realistic mock Prisma LlmModel record."""
    provider = Mock()
    provider.name = "openai"
    provider.displayName = "OpenAI"

    record = Mock()
    record.slug = "openai/gpt-4o"
    record.displayName = "GPT-4o"
    record.description = "Latest GPT model"
    record.providerId = "provider-uuid"
    record.Provider = provider
    record.creatorId = "creator-uuid"
    record.Creator = None
    record.contextWindow = 128000
    record.maxOutputTokens = 16384
    record.priceTier = 2
    record.isEnabled = True
    record.isRecommended = False
    record.supportsTools = True
    record.supportsJsonOutput = True
    record.supportsReasoning = False
    record.supportsParallelToolCalls = True
    record.capabilities = {}
    record.metadata = {}
    record.Costs = []

    for key, value in overrides.items():
        setattr(record, key, value)
    return record


def _make_registry_model(**kwargs) -> RegistryModel:
    """Build a minimal RegistryModel for testing registry-level functions."""
    from backend.blocks.llm import ModelMetadata

    defaults = dict(
        slug="openai/gpt-4o",
        display_name="GPT-4o",
        description=None,
        metadata=ModelMetadata(
            provider="openai",
            context_window=128000,
            max_output_tokens=16384,
            display_name="GPT-4o",
            provider_name="OpenAI",
            creator_name="Unknown",
            price_tier=2,
        ),
        capabilities={},
        extra_metadata={},
        provider_display_name="OpenAI",
        is_enabled=True,
        is_recommended=False,
    )
    defaults.update(kwargs)
    return RegistryModel(**defaults)


# ---------------------------------------------------------------------------
# _record_to_registry_model tests
# ---------------------------------------------------------------------------


def test_record_to_registry_model():
    """Happy-path: well-formed record produces a correct RegistryModel."""
    record = _make_mock_record()
    model = _record_to_registry_model(record)

    assert model.slug == "openai/gpt-4o"
    assert model.display_name == "GPT-4o"
    assert model.description == "Latest GPT model"
    assert model.provider_display_name == "OpenAI"
    assert model.is_enabled is True
    assert model.is_recommended is False
    assert model.supports_tools is True
    assert model.supports_json_output is True
    assert model.supports_reasoning is False
    assert model.supports_parallel_tool_calls is True
    assert model.metadata.provider == "openai"
    assert model.metadata.context_window == 128000
    assert model.metadata.max_output_tokens == 16384
    assert model.metadata.price_tier == 2
    assert model.creator is None
    assert model.costs == ()


def test_record_to_registry_model_missing_provider(caplog):
    """Record with no Provider relation falls back to providerId and logs a warning."""
    record = _make_mock_record(Provider=None, providerId="provider-uuid")
    with caplog.at_level("WARNING"):
        model = _record_to_registry_model(record)

    assert "no Provider" in caplog.text
    assert model.metadata.provider == "provider-uuid"
    assert model.provider_display_name == "provider-uuid"


def test_record_to_registry_model_missing_creator():
    """When Creator is None, creator_name defaults to 'Unknown' and creator field is None."""
    record = _make_mock_record(Creator=None)
    model = _record_to_registry_model(record)

    assert model.creator is None
    assert model.metadata.creator_name == "Unknown"


def test_record_to_registry_model_with_creator():
    """When Creator is present, it is parsed into RegistryModelCreator."""
    creator_mock = Mock()
    creator_mock.id = "creator-uuid"
    creator_mock.name = "openai"
    creator_mock.displayName = "OpenAI"
    creator_mock.description = "AI company"
    creator_mock.websiteUrl = "https://openai.com"
    creator_mock.logoUrl = "https://openai.com/logo.png"

    record = _make_mock_record(Creator=creator_mock)
    model = _record_to_registry_model(record)

    assert model.creator is not None
    assert isinstance(model.creator, RegistryModelCreator)
    assert model.creator.id == "creator-uuid"
    assert model.creator.display_name == "OpenAI"
    assert model.metadata.creator_name == "OpenAI"


def test_record_to_registry_model_null_max_output_tokens():
    """maxOutputTokens=None falls back to contextWindow."""
    record = _make_mock_record(maxOutputTokens=None, contextWindow=64000)
    model = _record_to_registry_model(record)

    assert model.metadata.max_output_tokens == 64000


def test_record_to_registry_model_invalid_price_tier(caplog):
    """Out-of-range priceTier is coerced to 1 and a warning is logged."""
    record = _make_mock_record(priceTier=99)
    with caplog.at_level("WARNING"):
        model = _record_to_registry_model(record)

    assert "out-of-range priceTier" in caplog.text
    assert model.metadata.price_tier == 1


def test_record_to_registry_model_with_costs():
    """Costs are parsed into RegistryModelCost tuples."""
    cost_mock = Mock()
    cost_mock.unit = "TOKENS"
    cost_mock.creditCost = 10
    cost_mock.credentialProvider = "openai"
    cost_mock.credentialId = None
    cost_mock.credentialType = None
    cost_mock.currency = "USD"
    cost_mock.metadata = {}

    record = _make_mock_record(Costs=[cost_mock])
    model = _record_to_registry_model(record)

    assert len(model.costs) == 1
    cost = model.costs[0]
    assert isinstance(cost, RegistryModelCost)
    assert cost.unit == "TOKENS"
    assert cost.credit_cost == 10
    assert cost.credential_provider == "openai"


# ---------------------------------------------------------------------------
# get_default_model_slug tests
# ---------------------------------------------------------------------------


def test_get_default_model_slug_recommended():
    """Recommended model is preferred over non-recommended enabled models."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {
        "openai/gpt-4o": _make_registry_model(
            slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
        ),
        "openai/gpt-4o-recommended": _make_registry_model(
            slug="openai/gpt-4o-recommended",
            display_name="GPT-4o Recommended",
            is_recommended=True,
        ),
    }

    result = get_default_model_slug()
    assert result == "openai/gpt-4o-recommended"


def test_get_default_model_slug_fallback():
    """With no recommended model, falls back to first enabled (alphabetical)."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {
        "openai/gpt-4o": _make_registry_model(
            slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
        ),
        "openai/gpt-3.5": _make_registry_model(
            slug="openai/gpt-3.5", display_name="GPT-3.5", is_recommended=False
        ),
    }

    result = get_default_model_slug()
    # Sorted alphabetically: GPT-3.5 < GPT-4o
    assert result == "openai/gpt-3.5"


def test_get_default_model_slug_empty():
    """Empty registry returns None."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {}

    result = get_default_model_slug()
    assert result is None


# ---------------------------------------------------------------------------
# _build_schema_options / get_schema_options tests
# ---------------------------------------------------------------------------


def test_build_schema_options():
    """Only enabled models appear, sorted case-insensitively."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {
        "openai/gpt-4o": _make_registry_model(
            slug="openai/gpt-4o", display_name="GPT-4o", is_enabled=True
        ),
        "openai/disabled": _make_registry_model(
            slug="openai/disabled", display_name="Disabled Model", is_enabled=False
        ),
        "openai/gpt-3.5": _make_registry_model(
            slug="openai/gpt-3.5", display_name="gpt-3.5", is_enabled=True
        ),
    }

    options = _build_schema_options()
    slugs = [o["value"] for o in options]

    # disabled model should be excluded
    assert "openai/disabled" not in slugs
    # only enabled models
    assert "openai/gpt-4o" in slugs
    assert "openai/gpt-3.5" in slugs
    # case-insensitive sort: "gpt-3.5" < "GPT-4o" (both lowercase: "gpt-3.5" < "gpt-4o")
    assert slugs.index("openai/gpt-3.5") < slugs.index("openai/gpt-4o")

    # Verify structure
    for option in options:
        assert "label" in option
        assert "value" in option
        assert "group" in option
        assert "description" in option


def test_get_schema_options_returns_copy():
    """Mutating the returned list does not affect the internal cache."""
    import backend.data.llm_registry.registry as reg

    reg._dynamic_models = {
        "openai/gpt-4o": _make_registry_model(slug="openai/gpt-4o", display_name="GPT-4o"),
    }
    reg._schema_options = _build_schema_options()

    options = get_schema_options()
    original_length = len(options)
    options.append({"label": "Injected", "value": "evil/model", "group": "evil", "description": ""})

    # Internal state should be unchanged
    assert len(get_schema_options()) == original_length


# ---------------------------------------------------------------------------
# Pydantic frozen model tests
# ---------------------------------------------------------------------------


def test_registry_model_frozen():
    """Pydantic frozen=True should reject attribute assignment."""
    model = _make_registry_model()

    with pytest.raises((pydantic.ValidationError, TypeError)):
        model.slug = "changed/slug"  # type: ignore[misc]


def test_registry_model_cost_frozen():
    """RegistryModelCost is also frozen."""
    cost = RegistryModelCost(
        unit="TOKENS",
        credit_cost=5,
        credential_provider="openai",
    )

    with pytest.raises((pydantic.ValidationError, TypeError)):
        cost.unit = "RUN"  # type: ignore[misc]


# ---------------------------------------------------------------------------
# refresh_llm_registry tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_refresh_llm_registry():
    """Mock prisma find_many, verify cache is populated after refresh."""
    import backend.data.llm_registry.registry as reg

    record = _make_mock_record()
    mock_find_many = AsyncMock(return_value=[record])

    with patch("prisma.models.LlmModel.prisma") as mock_prisma_cls:
        mock_prisma_instance = Mock()
        mock_prisma_instance.find_many = mock_find_many
        mock_prisma_cls.return_value = mock_prisma_instance

        # Clear state first
        reg._dynamic_models = {}
        reg._schema_options = []

        await refresh_llm_registry()

    assert "openai/gpt-4o" in reg._dynamic_models
    model = reg._dynamic_models["openai/gpt-4o"]
    assert isinstance(model, RegistryModel)
    assert model.slug == "openai/gpt-4o"

    # Schema options should be populated too
    assert len(reg._schema_options) == 1
    assert reg._schema_options[0]["value"] == "openai/gpt-4o"
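
One ordering detail is not covered by the tests above. `_build_schema_options()` sorts by `display_name.lower()`, while `get_default_model_slug()` sorts by `display_name` as-is, and the fallback test only uses names that sort the same either way ("GPT-3.5", "GPT-4o"). The sketch below shows how the two sort keys can pick a different first model when display names mix case. `FakeModel` and `first_enabled` are hypothetical helpers written for this illustration, not part of the registry module.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeModel:
    """Hypothetical stand-in for RegistryModel with only the fields used here."""

    slug: str
    display_name: str
    is_enabled: bool = True


def first_enabled(models, key):
    """Return the slug of the first enabled model under the given sort key."""
    ordered = sorted(models, key=key)
    return next((m.slug for m in ordered if m.is_enabled), None)


models = [
    FakeModel(slug="openai/gpt-3.5", display_name="gpt-3.5"),
    FakeModel(slug="openai/gpt-4o", display_name="GPT-4o"),
]

# Case-sensitive key: uppercase "G" sorts before lowercase "g".
case_sensitive = first_enabled(models, key=lambda m: m.display_name)

# Case-insensitive key: "gpt-3.5" sorts before "gpt-4o".
case_insensitive = first_enabled(models, key=lambda m: m.display_name.lower())

print(case_sensitive, case_insensitive)
```

If the default model is expected to match the first entry of the dropdown, a test with mixed-case display names along these lines would catch the difference.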


@@ -41,7 +41,7 @@ _MAX_PAGES = 100
_LLM_TIMEOUT = 30
def mask_email(email: str) -> str:
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
try:
local, domain = email.rsplit("@", 1)
@@ -196,7 +196,8 @@ async def _refresh_cache(form_id: str) -> tuple[dict, list]:
Returns (email_index, questions).
"""
client = _make_tally_client(_settings.secrets.tally_api_key)
settings = Settings()
client = _make_tally_client(settings.secrets.tally_api_key)
redis = await get_redis_async()
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
@@ -331,9 +332,6 @@ Fields:
- current_software (list of strings): software/tools currently used
- existing_automation (list of strings): existing automations
- additional_notes (string): any additional context
- suggested_prompts (list of 5 strings): short action prompts (each under 20 words) that would help \
this person get started with automating their work. Should be specific to their industry, role, and \
pain points; actionable and conversational in tone; focused on automation opportunities.
Form data:
"""
@@ -341,21 +339,21 @@ Form data:
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
async def extract_business_understanding_from_tally(
async def extract_business_understanding(
formatted_text: str,
) -> BusinessUnderstandingInput:
"""
Use an LLM to extract structured business understanding from form text.
"""Use an LLM to extract structured business understanding from form text.
Raises on timeout or unparseable response so the caller can handle it.
"""
api_key = _settings.secrets.open_router_api_key
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL)
try:
response = await asyncio.wait_for(
client.chat.completions.create(
model=_settings.config.tally_extraction_llm_model,
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
@@ -380,57 +378,9 @@ async def extract_business_understanding_from_tally(
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
# Validate suggested_prompts: filter >20 words, keep top 3
raw_prompts = cleaned.get("suggested_prompts", [])
if isinstance(raw_prompts, list):
valid = [
p.strip()
for p in raw_prompts
if isinstance(p, str) and len(p.strip().split()) <= 20
]
# This will keep up to 3 suggestions
short_prompts = valid[:3] if valid else None
if short_prompts:
cleaned["suggested_prompts"] = short_prompts
else:
# We dont want to add a None value suggested_prompts field
cleaned.pop("suggested_prompts", None)
else:
# suggested_prompts must be a list - removing it as its not here
cleaned.pop("suggested_prompts", None)
return BusinessUnderstandingInput(**cleaned)
async def get_business_understanding_input_from_tally(
email: str,
*,
require_api_key: bool = False,
) -> Optional[BusinessUnderstandingInput]:
if not _settings.secrets.tally_api_key:
if require_api_key:
raise RuntimeError("Tally API key is not configured")
logger.debug("Tally: no API key configured, skipping")
return None
masked = mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return None
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return None
return await extract_business_understanding_from_tally(formatted)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
@@ -445,9 +395,32 @@ async def populate_understanding_from_tally(user_id: str, email: str) -> None:
)
return
understanding_input = await get_business_understanding_input_from_tally(email)
if understanding_input is None:
# Check required config is present
settings = Settings()
if not settings.secrets.tally_api_key or not settings.secrets.tally_form_id:
logger.debug("Tally: Tally config incomplete, skipping")
return
if not settings.secrets.open_router_api_key:
logger.debug("Tally: no OpenRouter API key configured, skipping")
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(settings.secrets.tally_form_id, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)


@@ -12,11 +12,11 @@ from backend.data.tally import (
_build_email_index,
_format_answer,
_make_tally_client,
_mask_email,
_refresh_cache,
extract_business_understanding_from_tally,
extract_business_understanding,
find_submission_by_email,
format_submission_for_llm,
mask_email,
populate_understanding_from_tally,
)
@@ -248,7 +248,7 @@ async def test_populate_understanding_skips_no_api_key():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
@@ -284,7 +284,6 @@ async def test_populate_understanding_full_flow():
],
}
mock_input = MagicMock()
mock_input.suggested_prompts = ["Prompt 1", "Prompt 2", "Prompt 3"]
with (
patch(
@@ -292,14 +291,14 @@ async def test_populate_understanding_full_flow():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding_from_tally",
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
return_value=mock_input,
) as mock_extract,
@@ -332,14 +331,14 @@ async def test_populate_understanding_handles_llm_timeout():
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding_from_tally",
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError(),
),
@@ -357,13 +356,13 @@ async def test_populate_understanding_handles_llm_timeout():
def test_mask_email():
assert mask_email("alice@example.com") == "a***e@example.com"
assert mask_email("ab@example.com") == "a***@example.com"
assert mask_email("a@example.com") == "a***@example.com"
assert _mask_email("alice@example.com") == "a***e@example.com"
assert _mask_email("ab@example.com") == "a***@example.com"
assert _mask_email("a@example.com") == "a***@example.com"
def test_mask_email_invalid():
assert mask_email("no-at-sign") == "***"
assert _mask_email("no-at-sign") == "***"
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
@@ -394,11 +393,11 @@ def test_extraction_prompt_no_format_placeholders():
assert single_braces == [], f"Found format placeholders: {single_braces}"
# ── extract_business_understanding_from_tally ────────────────────────────────────────────
# ── extract_business_understanding ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_success():
async def test_extract_business_understanding_success():
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
@@ -407,13 +406,6 @@ async def test_extract_business_understanding_from_tally_success():
"business_name": "Acme Corp",
"industry": "Technology",
"pain_points": ["manual reporting"],
"suggested_prompts": [
"Automate weekly reports",
"Set up invoice processing",
"Create a customer onboarding flow",
"Track project deadlines automatically",
"Send follow-up emails after meetings",
],
}
)
mock_response = MagicMock()
@@ -423,56 +415,16 @@ async def test_extract_business_understanding_from_tally_success():
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name == "Acme Corp"
assert result.industry == "Technology"
assert result.pain_points == ["manual reporting"]
# suggested_prompts validated and sliced to top 3
assert result.suggested_prompts == [
"Automate weekly reports",
"Set up invoice processing",
"Create a customer onboarding flow",
]
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_filters_long_prompts():
"""Prompts exceeding 20 words are excluded and only top 3 are kept."""
long_prompt = " ".join(["word"] * 21)
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"suggested_prompts": [
long_prompt,
"Short prompt one",
long_prompt,
"Short prompt two",
"Short prompt three",
"Short prompt four",
],
}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
assert result.suggested_prompts == [
"Short prompt one",
"Short prompt two",
"Short prompt three",
]
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_filters_nulls():
async def test_extract_business_understanding_filters_nulls():
"""Null values from LLM should be excluded from the result."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
@@ -485,7 +437,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name is None
@@ -493,7 +445,7 @@ async def test_extract_business_understanding_from_tally_filters_nulls():
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_invalid_json():
async def test_extract_business_understanding_invalid_json():
"""Invalid JSON from LLM should raise JSONDecodeError."""
mock_choice = MagicMock()
mock_choice.message.content = "not valid json {"
@@ -507,11 +459,11 @@ async def test_extract_business_understanding_from_tally_invalid_json():
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
pytest.raises(json.JSONDecodeError),
):
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
await extract_business_understanding("Q: Name?\nA: Alice")
@pytest.mark.asyncio
async def test_extract_business_understanding_from_tally_timeout():
async def test_extract_business_understanding_timeout():
"""LLM timeout should propagate as asyncio.TimeoutError."""
mock_client = AsyncMock()
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
@@ -521,7 +473,7 @@ async def test_extract_business_understanding_from_tally_timeout():
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
pytest.raises(asyncio.TimeoutError),
):
await extract_business_understanding_from_tally("Q: Name?\nA: Alice")
await extract_business_understanding("Q: Name?\nA: Alice")
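The timeout test above patches `_LLM_TIMEOUT` to a very small value and expects `asyncio.TimeoutError` to reach the caller. A minimal sketch of that behaviour, assuming the LLM call is wrapped in `asyncio.wait_for` (the module's real implementation is not shown in this diff; `LLM_TIMEOUT`, `slow_llm_call`, `extract`, and `timed_out` are hypothetical names):

```python
import asyncio

LLM_TIMEOUT = 0.01  # hypothetical stand-in for the module's _LLM_TIMEOUT


async def slow_llm_call() -> str:
    """Pretend LLM request that takes far longer than the timeout."""
    await asyncio.sleep(1)
    return "{}"


async def extract() -> str:
    # wait_for cancels the inner call and raises TimeoutError on expiry;
    # nothing here catches it, so it propagates to the caller.
    return await asyncio.wait_for(slow_llm_call(), timeout=LLM_TIMEOUT)


def timed_out() -> bool:
    """Run the coroutine and report whether the timeout propagated."""
    try:
        asyncio.run(extract())
    except asyncio.TimeoutError:
        return True
    return False
```

Leaving the exception uncaught in `extract` is what lets a test assert on it with `pytest.raises(asyncio.TimeoutError)`.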
# ── _refresh_cache ───────────────────────────────────────────────────────────
@@ -540,7 +492,7 @@ async def test_refresh_cache_full_fetch():
submissions = SAMPLE_SUBMISSIONS
with (
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
@@ -588,7 +540,7 @@ async def test_refresh_cache_incremental_fetch():
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
with (
patch("backend.data.tally._settings", mock_settings),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,


@@ -86,11 +86,6 @@ class BusinessUnderstandingInput(pydantic.BaseModel):
None, description="Any additional context"
)
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: Optional[list[str]] = pydantic.Field(
None, description="LLM-generated suggested prompts based on business context"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
@@ -127,9 +122,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
# Additional context
additional_notes: Optional[str] = None
# Suggested prompts (UI-only, not included in system prompt)
suggested_prompts: list[str] = pydantic.Field(default_factory=list)
@classmethod
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
@@ -157,7 +149,6 @@ class BusinessUnderstanding(pydantic.BaseModel):
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
additional_notes=business.get("additional_notes"),
suggested_prompts=_json_to_list(data.get("suggested_prompts")),
)
@@ -175,62 +166,6 @@ def _merge_lists(existing: list | None, new: list | None) -> list | None:
return merged
def merge_business_understanding_data(
existing_data: dict[str, Any],
input_data: BusinessUnderstandingInput,
) -> dict[str, Any]:
merged_data = dict(existing_data)
merged_business: dict[str, Any] = {}
if isinstance(merged_data.get("business"), dict):
merged_business = dict(merged_data["business"])
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
if input_data.user_name is not None:
merged_data["name"] = input_data.user_name
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
merged_business[field] = value
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(merged_business.get(field))
merged_list = _merge_lists(existing_list, value)
merged_business[field] = merged_list
merged_business["version"] = 1
merged_data["business"] = merged_business
# suggested_prompts lives at the top level (not under `business`) because
# it's a UI-only artifact consumed by the frontend, not business understanding
# data. The `business` sub-dict feeds the system prompt.
if input_data.suggested_prompts is not None:
merged_data["suggested_prompts"] = input_data.suggested_prompts
return merged_data
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
@@ -310,18 +245,63 @@ async def upsert_business_understanding(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
merged_data = merge_business_understanding_data(existing_data, input_data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(merged_data)},
"update": {"data": SafeJson(merged_data)},
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
},
)
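The upsert logic above applies two rules: string fields are overwritten when a value is provided, and list fields are merged with what is already stored. A minimal sketch of those rules on plain dicts, assuming `_merge_lists` appends new items that are not already present (its body is not shown in this diff; `merge_lists`, `merge_understanding`, and the shortened field lists are illustrative only):

```python
from typing import Any, Dict, List, Optional

STRING_FIELDS = ["job_title", "business_name", "industry"]
LIST_FIELDS = ["pain_points", "automation_goals"]


def merge_lists(existing: Optional[List[Any]], new: Optional[List[Any]]) -> List[Any]:
    """Append items from `new` that are not already in `existing` (assumed behaviour)."""
    merged = list(existing or [])
    for item in new or []:
        if item not in merged:
            merged.append(item)
    return merged


def merge_understanding(existing: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with `updates` folded into `existing`."""
    data = dict(existing)
    business = dict(data.get("business") or {})
    # Top-level name field: overwrite only when provided.
    if updates.get("user_name") is not None:
        data["name"] = updates["user_name"]
    # String fields: overwrite only when provided.
    for field in STRING_FIELDS:
        if updates.get(field) is not None:
            business[field] = updates[field]
    # List fields: merge with the stored list instead of replacing it.
    for field in LIST_FIELDS:
        if updates.get(field) is not None:
            business[field] = merge_lists(business.get(field), updates[field])
    business["version"] = 1
    data["business"] = business
    return data
```

Copying `existing` and its `business` sub-dict before writing keeps the caller's data unchanged, which makes the merge easy to test in isolation.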


@@ -1,102 +0,0 @@
"""Tests for business understanding merge and format logic."""
from datetime import datetime, timezone
from typing import Any
from backend.data.understanding import (
BusinessUnderstanding,
BusinessUnderstandingInput,
format_understanding_for_prompt,
merge_business_understanding_data,
)
def _make_input(**kwargs: Any) -> BusinessUnderstandingInput:
"""Create a BusinessUnderstandingInput with only the specified fields."""
return BusinessUnderstandingInput.model_validate(kwargs)
# ─── merge_business_understanding_data: suggested_prompts ─────────────
def test_merge_suggested_prompts_overwrites_existing():
"""New suggested_prompts should fully replace existing ones (not append)."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Old prompt 1", "Old prompt 2"],
}
input_data = _make_input(
suggested_prompts=["New prompt A", "New prompt B", "New prompt C"],
)
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == [
"New prompt A",
"New prompt B",
"New prompt C",
]
def test_merge_suggested_prompts_none_preserves_existing():
"""When input has suggested_prompts=None, existing prompts are preserved."""
existing = {
"name": "Alice",
"business": {"industry": "Tech", "version": 1},
"suggested_prompts": ["Keep me"],
}
input_data = _make_input(industry="Finance")
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == ["Keep me"]
assert result["business"]["industry"] == "Finance"
def test_merge_suggested_prompts_added_to_empty_data():
"""Suggested prompts are set at top level even when starting from empty data."""
existing: dict[str, Any] = {}
input_data = _make_input(suggested_prompts=["Prompt 1"])
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == ["Prompt 1"]
def test_merge_suggested_prompts_empty_list_overwrites():
"""An explicit empty list should overwrite existing prompts."""
existing: dict[str, Any] = {
"suggested_prompts": ["Old prompt"],
"business": {"version": 1},
}
input_data = _make_input(suggested_prompts=[])
result = merge_business_understanding_data(existing, input_data)
assert result["suggested_prompts"] == []
# ─── format_understanding_for_prompt: excludes suggested_prompts ──────
def test_format_understanding_excludes_suggested_prompts():
"""suggested_prompts is UI-only and must NOT appear in the system prompt."""
understanding = BusinessUnderstanding(
id="test-id",
user_id="user-1",
created_at=datetime.now(tz=timezone.utc),
updated_at=datetime.now(tz=timezone.utc),
user_name="Alice",
industry="Technology",
suggested_prompts=["Automate reports", "Set up alerts", "Track KPIs"],
)
formatted = format_understanding_for_prompt(understanding)
assert "Alice" in formatted
assert "Technology" in formatted
assert "suggested_prompts" not in formatted
assert "Automate reports" not in formatted
assert "Set up alerts" not in formatted
assert "Track KPIs" not in formatted


@@ -224,7 +224,7 @@ async def execute_node(
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
log_metadata.warning(f"Skip execution, input validation error: {error}")
yield "error", error
return


@@ -0,0 +1,6 @@
"""LLM registry API (public + admin)."""
from .admin_routes import router as admin_router
from .routes import router
__all__ = ["router", "admin_router"]


@@ -0,0 +1,115 @@
"""Request/response models for LLM registry admin API."""
from typing import Any
from pydantic import BaseModel, Field
class CreateLlmProviderRequest(BaseModel):
"""Request model for creating an LLM provider."""
name: str = Field(
..., description="Provider identifier (e.g., 'openai', 'anthropic')"
)
display_name: str = Field(..., description="Human-readable provider name")
description: str | None = Field(None, description="Provider description")
default_credential_provider: str | None = Field(
None, description="Default credential system identifier"
)
default_credential_id: str | None = Field(None, description="Default credential ID")
default_credential_type: str | None = Field(
None, description="Default credential type"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
class UpdateLlmProviderRequest(BaseModel):
"""Request model for updating an LLM provider."""
display_name: str | None = Field(None, description="Human-readable provider name")
description: str | None = Field(None, description="Provider description")
default_credential_provider: str | None = Field(
None, description="Default credential system identifier"
)
default_credential_id: str | None = Field(None, description="Default credential ID")
default_credential_type: str | None = Field(
None, description="Default credential type"
)
metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
class CreateLlmModelRequest(BaseModel):
"""Request model for creating an LLM model."""
slug: str = Field(..., description="Model slug (e.g., 'gpt-4', 'claude-3-opus')")
display_name: str = Field(..., description="Human-readable model name")
description: str | None = Field(None, description="Model description")
provider_id: str = Field(..., description="Provider name or ID (UUID)")
creator_id: str | None = Field(None, description="Creator ID (UUID)")
context_window: int = Field(
..., description="Maximum context window in tokens", gt=0
)
max_output_tokens: int | None = Field(
None, description="Maximum output tokens (None if unlimited)", gt=0
)
price_tier: int = Field(
..., description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
)
is_enabled: bool = Field(default=True, description="Whether the model is enabled")
is_recommended: bool = Field(
default=False, description="Whether the model is recommended"
)
supports_tools: bool = Field(default=False, description="Supports function calling")
supports_json_output: bool = Field(
default=False, description="Supports JSON output mode"
)
supports_reasoning: bool = Field(
default=False, description="Supports reasoning mode"
)
supports_parallel_tool_calls: bool = Field(
default=False, description="Supports parallel tool calls"
)
capabilities: dict[str, Any] = Field(
default_factory=dict, description="Additional capabilities"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
costs: list[dict[str, Any]] = Field(
default_factory=list, description="Cost entries for the model"
)
class UpdateLlmModelRequest(BaseModel):
"""Request model for updating an LLM model."""
display_name: str | None = Field(None, description="Human-readable model name")
description: str | None = Field(None, description="Model description")
creator_id: str | None = Field(None, description="Creator ID (UUID)")
context_window: int | None = Field(
None, description="Maximum context window in tokens", gt=0
)
max_output_tokens: int | None = Field(
None, description="Maximum output tokens (None if unlimited)", gt=0
)
price_tier: int | None = Field(
None, description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
)
is_enabled: bool | None = Field(None, description="Whether the model is enabled")
is_recommended: bool | None = Field(
None, description="Whether the model is recommended"
)
supports_tools: bool | None = Field(None, description="Supports function calling")
supports_json_output: bool | None = Field(
None, description="Supports JSON output mode"
)
supports_reasoning: bool | None = Field(None, description="Supports reasoning mode")
supports_parallel_tool_calls: bool | None = Field(
None, description="Supports parallel tool calls"
)
capabilities: dict[str, Any] | None = Field(
None, description="Additional capabilities"
)
metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
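Every field in `UpdateLlmModelRequest` defaults to `None`, which suggests a partial update in which unset fields are left untouched. How `db_write.update_model` treats `None` is not shown here, so the following is only a sketch of one way to build such an update payload. `UpdateModelPatch`, `to_camel`, and `build_update_data` are hypothetical names, and the dataclass carries just three of the real fields:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class UpdateModelPatch:
    """Hypothetical stand-in for UpdateLlmModelRequest with a few of its fields."""

    display_name: Optional[str] = None
    context_window: Optional[int] = None
    is_enabled: Optional[bool] = None


def to_camel(name: str) -> str:
    """Convert snake_case to the camelCase used by the Prisma columns."""
    head, *rest = name.split("_")
    return head + "".join(part.capitalize() for part in rest)


def build_update_data(patch: UpdateModelPatch) -> Dict[str, Any]:
    """Keep only the fields the caller actually set (anything that is not None)."""
    data: Dict[str, Any] = {}
    for f in fields(patch):
        value = getattr(patch, f.name)
        if value is not None:
            data[to_camel(f.name)] = value
    return data
```

Note that `False` is kept because the check is `is not None` rather than truthiness, so `is_enabled=False` still disables a model.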


@@ -0,0 +1,689 @@
"""Admin API for LLM registry management.
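Provides endpoints for:
- Listing, creating, and updating creators
- Creating, updating, deleting, and toggling models, with optional node migration
- Listing and reverting model migrations
- Creating, updating, and deleting providers
- Listing all providers and models, including disabled or empty ones
All endpoints require admin authentication. Model, provider, and migration mutations refresh the registry cache.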
"""
import logging
from typing import Any
import prisma
import autogpt_libs.auth
from fastapi import APIRouter, HTTPException, Security, status
from backend.server.v2.llm import db_write
from backend.server.v2.llm.admin_model import (
CreateLlmModelRequest,
CreateLlmProviderRequest,
UpdateLlmModelRequest,
UpdateLlmProviderRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter()
def _map_provider_response(provider: Any) -> dict[str, Any]:
"""Map Prisma provider model to response dict."""
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.displayName,
"description": provider.description,
"default_credential_provider": provider.defaultCredentialProvider,
"default_credential_id": provider.defaultCredentialId,
"default_credential_type": provider.defaultCredentialType,
"metadata": dict(provider.metadata or {}),
"created_at": provider.createdAt.isoformat() if provider.createdAt else None,
"updated_at": provider.updatedAt.isoformat() if provider.updatedAt else None,
}
def _map_model_response(model: Any) -> dict[str, Any]:
"""Map Prisma model to response dict."""
return {
"id": model.id,
"slug": model.slug,
"display_name": model.displayName,
"description": model.description,
"provider_id": model.providerId,
"creator_id": model.creatorId,
"context_window": model.contextWindow,
"max_output_tokens": model.maxOutputTokens,
"price_tier": model.priceTier,
"is_enabled": model.isEnabled,
"is_recommended": model.isRecommended,
"supports_tools": model.supportsTools,
"supports_json_output": model.supportsJsonOutput,
"supports_reasoning": model.supportsReasoning,
"supports_parallel_tool_calls": model.supportsParallelToolCalls,
"capabilities": dict(model.capabilities or {}),
"metadata": dict(model.metadata or {}),
"created_at": model.createdAt.isoformat() if model.createdAt else None,
"updated_at": model.updatedAt.isoformat() if model.updatedAt else None,
}
def _map_creator_response(creator: Any) -> dict[str, Any]:
"""Map Prisma creator model to response dict."""
return {
"id": creator.id,
"name": creator.name,
"display_name": creator.displayName,
"description": creator.description,
"website_url": creator.websiteUrl,
"logo_url": creator.logoUrl,
"metadata": dict(creator.metadata or {}),
"created_at": creator.createdAt.isoformat() if creator.createdAt else None,
"updated_at": creator.updatedAt.isoformat() if creator.updatedAt else None,
}
@router.post(
"/llm/models",
status_code=status.HTTP_201_CREATED,
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_model(
request: CreateLlmModelRequest,
) -> dict[str, Any]:
"""Create a new LLM model.
Requires admin authentication.
"""
try:
import prisma.models as pm
# Resolve provider name to ID
provider = await pm.LlmProvider.prisma().find_unique(
where={"name": request.provider_id}
)
if not provider:
# Try as UUID fallback
provider = await pm.LlmProvider.prisma().find_unique(
where={"id": request.provider_id}
)
if not provider:
raise HTTPException(
status_code=404,
detail=f"Provider '{request.provider_id}' not found",
)
model = await db_write.create_model(
slug=request.slug,
display_name=request.display_name,
provider_id=provider.id,
context_window=request.context_window,
price_tier=request.price_tier,
description=request.description,
creator_id=request.creator_id,
max_output_tokens=request.max_output_tokens,
is_enabled=request.is_enabled,
is_recommended=request.is_recommended,
supports_tools=request.supports_tools,
supports_json_output=request.supports_json_output,
supports_reasoning=request.supports_reasoning,
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
capabilities=request.capabilities,
metadata=request.metadata,
)
# Create cost entries if any were provided in the request
if request.costs:
for cost_input in request.costs:
await pm.LlmModelCost.prisma().create(
data={
"unit": cost_input.get("unit", "RUN"),
"creditCost": int(cost_input.get("credit_cost", 1)),
"credentialProvider": provider.name,
"metadata": prisma.Json(cost_input.get("metadata", {})),
"Model": {"connect": {"id": model.id}},
}
)
await db_write.refresh_runtime_caches()
logger.info(f"Created model '{request.slug}' (id: {model.id})")
# Re-fetch with costs included
model = await pm.LlmModel.prisma().find_unique(
where={"id": model.id},
include={"Costs": True, "Creator": True},
)
return _map_model_response(model)
except HTTPException:
# Preserve the intended status code (e.g. 404) instead of masking it as a 500
raise
except ValueError as e:
logger.warning(f"Model creation validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to create model: {e}")
raise HTTPException(status_code=500, detail="Failed to create model")
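Because `HTTPException` is itself an `Exception`, a 404 raised inside a `try` block is caught by a trailing `except Exception` clause unless a narrower clause handles it first. A minimal sketch of the difference, using a hypothetical `HTTPError` class in place of FastAPI's exception and returning the status code a client would see:

```python
class HTTPError(Exception):
    """Hypothetical stand-in for fastapi.HTTPException."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def handler_masking(found: bool) -> int:
    """Broad except only: the 404 is swallowed and reported as a 500."""
    try:
        if not found:
            raise HTTPError(404, "Provider not found")
        return 200
    except Exception:
        # This clause also catches the HTTPError raised above.
        return 500


def handler_preserving(found: bool) -> int:
    """Narrow clause first: the intended status code survives."""
    try:
        if not found:
            raise HTTPError(404, "Provider not found")
        return 200
    except HTTPError as e:
        # Handled before the broad clause, so 404 stays 404.
        return e.status_code
    except Exception:
        return 500
```

In a real FastAPI handler the narrow clause would simply re-raise, leaving the framework to render the response.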
@router.patch(
"/llm/models/{slug:path}",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_model(
slug: str,
request: UpdateLlmModelRequest,
) -> dict[str, Any]:
"""Update an existing LLM model.
Requires admin authentication.
"""
try:
# Find model by slug first to get ID
import prisma.models
existing = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": slug}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Model with slug '{slug}' not found"
)
model = await db_write.update_model(
model_id=existing.id,
display_name=request.display_name,
description=request.description,
creator_id=request.creator_id,
context_window=request.context_window,
max_output_tokens=request.max_output_tokens,
price_tier=request.price_tier,
is_enabled=request.is_enabled,
is_recommended=request.is_recommended,
supports_tools=request.supports_tools,
supports_json_output=request.supports_json_output,
supports_reasoning=request.supports_reasoning,
supports_parallel_tool_calls=request.supports_parallel_tool_calls,
capabilities=request.capabilities,
metadata=request.metadata,
)
await db_write.refresh_runtime_caches()
logger.info(f"Updated model '{slug}' (id: {model.id})")
return _map_model_response(model)
except HTTPException:
# Preserve the intended status code (e.g. 404) instead of masking it as a 500
raise
except ValueError as e:
logger.warning(f"Model update validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to update model: {e}")
raise HTTPException(status_code=500, detail="Failed to update model")
@router.delete(
"/llm/models/{slug:path}",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_model(
slug: str,
replacement_model_slug: str | None = None,
) -> dict[str, Any]:
"""Delete an LLM model with optional migration.
If workflows are using this model and no replacement_model_slug is given,
returns 400 with the node count. Provide replacement_model_slug to migrate
affected nodes before deletion.
"""
try:
import prisma.models
existing = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": slug}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Model with slug '{slug}' not found"
)
result = await db_write.delete_model(
model_id=existing.id,
replacement_model_slug=replacement_model_slug,
)
await db_write.refresh_runtime_caches()
logger.info(
f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
)
return result
except HTTPException:
# Preserve the intended status code (e.g. 404) instead of masking it as a 500
raise
except ValueError as e:
logger.warning(f"Model deletion validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to delete model: {e}")
raise HTTPException(status_code=500, detail="Failed to delete model")
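The `delete_model` docstring describes a guard: deletion is refused while nodes still reference the model, unless a replacement slug is supplied to migrate them first. A minimal in-memory sketch of that rule follows. The real logic lives in `db_write.delete_model` and runs against the database, so the `models` set, the `node_models` mapping, and the `"deleted"` key are assumptions for illustration; only `nodes_migrated` appears in the handler above.

```python
from typing import Any, Dict, Optional, Set


def delete_model(
    models: Set[str],
    node_models: Dict[str, str],
    slug: str,
    replacement_slug: Optional[str] = None,
) -> Dict[str, Any]:
    """Delete `slug`, migrating nodes to `replacement_slug` when one is given."""
    if slug not in models:
        raise KeyError(slug)
    in_use = [node for node, model in node_models.items() if model == slug]
    # Refuse to orphan nodes: the caller must name a replacement.
    if in_use and replacement_slug is None:
        raise ValueError(f"{len(in_use)} nodes still use '{slug}'")
    if replacement_slug is not None:
        if replacement_slug == slug or replacement_slug not in models:
            raise ValueError(f"Invalid replacement model '{replacement_slug}'")
    # Migrate first, then delete, so no node ever points at a missing model.
    for node in in_use:
        node_models[node] = replacement_slug
    models.discard(slug)
    return {"deleted": slug, "nodes_migrated": len(in_use)}
```

Raising `ValueError` for the guarded cases matches how the handler maps validation failures to a 400 response.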
@router.get(
"/llm/models/{slug:path}/usage",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def get_model_usage(slug: str) -> dict[str, Any]:
"""Get usage count for a model — how many workflow nodes reference it."""
try:
return await db_write.get_model_usage(slug)
except Exception as e:
logger.exception(f"Failed to get model usage: {e}")
raise HTTPException(status_code=500, detail="Failed to get model usage")
@router.post(
"/llm/models/{slug:path}/toggle",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def toggle_model(
slug: str,
request: dict[str, Any],
) -> dict[str, Any]:
"""Toggle a model's enabled status with optional migration when disabling.
Body params:
is_enabled: bool
migrate_to_slug: optional str
migration_reason: optional str
custom_credit_cost: optional int
"""
try:
import prisma.models
existing = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": slug}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Model with slug '{slug}' not found"
)
result = await db_write.toggle_model_with_migration(
model_id=existing.id,
is_enabled=request.get("is_enabled", True),
migrate_to_slug=request.get("migrate_to_slug"),
migration_reason=request.get("migration_reason"),
custom_credit_cost=request.get("custom_credit_cost"),
)
await db_write.refresh_runtime_caches()
logger.info(
f"Toggled model '{slug}' enabled={request.get('is_enabled')} "
f"(migrated {result['nodes_migrated']} nodes)"
)
return result
except HTTPException:
# Preserve the intended status code (e.g. 404) instead of masking it as a 500
raise
except ValueError as e:
logger.warning(f"Model toggle failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to toggle model: {e}")
raise HTTPException(status_code=500, detail="Failed to toggle model")
@router.get(
"/llm/migrations",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_migrations(
include_reverted: bool = False,
) -> dict[str, Any]:
"""List model migrations."""
try:
migrations = await db_write.list_migrations(
include_reverted=include_reverted
)
return {"migrations": migrations}
except Exception as e:
logger.exception(f"Failed to list migrations: {e}")
raise HTTPException(
status_code=500, detail="Failed to list migrations"
)
@router.post(
"/llm/migrations/{migration_id}/revert",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> dict[str, Any]:
"""Revert a model migration, restoring affected nodes."""
try:
result = await db_write.revert_migration(
migration_id=migration_id,
re_enable_source_model=re_enable_source_model,
)
await db_write.refresh_runtime_caches()
logger.info(
f"Reverted migration {migration_id}: "
f"{result['nodes_reverted']} nodes restored"
)
return result
except ValueError as e:
logger.warning(f"Migration revert failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to revert migration: {e}")
raise HTTPException(
status_code=500, detail="Failed to revert migration"
)
@router.post(
"/llm/providers",
status_code=status.HTTP_201_CREATED,
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_provider(
request: CreateLlmProviderRequest,
) -> dict[str, Any]:
"""Create a new LLM provider.
Requires admin authentication.
"""
try:
provider = await db_write.create_provider(
name=request.name,
display_name=request.display_name,
description=request.description,
default_credential_provider=request.default_credential_provider,
default_credential_id=request.default_credential_id,
default_credential_type=request.default_credential_type,
metadata=request.metadata,
)
await db_write.refresh_runtime_caches()
logger.info(f"Created provider '{request.name}' (id: {provider.id})")
return _map_provider_response(provider)
except ValueError as e:
logger.warning(f"Provider creation validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to create provider: {e}")
raise HTTPException(status_code=500, detail="Failed to create provider")
@router.patch(
"/llm/providers/{name}",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_provider(
name: str,
request: UpdateLlmProviderRequest,
) -> dict[str, Any]:
"""Update an existing LLM provider.
Requires admin authentication.
"""
try:
# Find provider by name first to get ID
import prisma.models
existing = await prisma.models.LlmProvider.prisma().find_unique(
where={"name": name}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Provider with name '{name}' not found"
)
provider = await db_write.update_provider(
provider_id=existing.id,
display_name=request.display_name,
description=request.description,
default_credential_provider=request.default_credential_provider,
default_credential_id=request.default_credential_id,
default_credential_type=request.default_credential_type,
metadata=request.metadata,
)
await db_write.refresh_runtime_caches()
logger.info(f"Updated provider '{name}' (id: {provider.id})")
return _map_provider_response(provider)
except HTTPException:
# Preserve the intended status code (e.g. 404) instead of masking it as a 500
raise
except ValueError as e:
logger.warning(f"Provider update validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to update provider: {e}")
raise HTTPException(status_code=500, detail="Failed to update provider")
@router.delete(
"/llm/providers/{name}",
status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_provider(
name: str,
) -> None:
"""Delete an LLM provider.
Requires admin authentication.
A provider can only be deleted if it has no associated models.
"""
try:
# Find provider by name first to get ID
import prisma.models
existing = await prisma.models.LlmProvider.prisma().find_unique(
where={"name": name}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Provider with name '{name}' not found"
)
await db_write.delete_provider(provider_id=existing.id)
await db_write.refresh_runtime_caches()
logger.info(f"Deleted provider '{name}' (id: {existing.id})")
except HTTPException:
# Preserve the intended status code (e.g. 404) instead of masking it as a 500
raise
except ValueError as e:
logger.warning(f"Provider deletion validation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Failed to delete provider: {e}")
raise HTTPException(status_code=500, detail="Failed to delete provider")
@router.get(
"/llm/admin/providers",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_providers() -> dict[str, Any]:
"""List all LLM providers from the database.
Unlike the public endpoint, this returns ALL providers including
those with no models. Requires admin authentication.
"""
try:
import prisma.models
providers = await prisma.models.LlmProvider.prisma().find_many(
order={"name": "asc"},
include={"Models": True},
)
return {
"providers": [
{**_map_provider_response(p), "model_count": len(p.Models) if p.Models else 0}
for p in providers
]
}
except Exception as e:
logger.exception(f"Failed to list providers: {e}")
raise HTTPException(status_code=500, detail="Failed to list providers")
@router.get(
"/llm/admin/models",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_models(
page: int = 1,
page_size: int = 100,
enabled_only: bool = False,
) -> dict[str, Any]:
"""List all LLM models from the database.
Unlike the public endpoint, this returns full model data including
costs and creator info. Requires admin authentication.
"""
try:
import prisma.models
where = {"isEnabled": True} if enabled_only else {}
models = await prisma.models.LlmModel.prisma().find_many(
where=where,
skip=(page - 1) * page_size,
take=page_size,
order={"displayName": "asc"},
include={"Costs": True, "Creator": True},
)
return {
"models": [
{
**_map_model_response(m),
"creator": _map_creator_response(m.Creator) if m.Creator else None,
"costs": [
{
"unit": c.unit,
"credit_cost": float(c.creditCost),
"credential_provider": c.credentialProvider,
"credential_type": c.credentialType,
"metadata": dict(c.metadata or {}),
}
for c in (m.Costs or [])
],
}
for m in models
]
}
except Exception as e:
logger.exception(f"Failed to list models: {e}")
raise HTTPException(status_code=500, detail="Failed to list models")
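`admin_list_models` turns the 1-based `page` parameter into a Prisma `skip` of `(page - 1) * page_size`, which goes negative if a client sends `page=0`. A minimal sketch of the same arithmetic with clamping added. The clamping and the `page_window` helper are assumptions, not part of this diff:

```python
from typing import Tuple


def page_window(page: int, page_size: int, max_page_size: int = 100) -> Tuple[int, int]:
    """Translate 1-based page params into (skip, take), clamping bad input."""
    page = max(page, 1)
    page_size = min(max(page_size, 1), max_page_size)
    return (page - 1) * page_size, page_size
```

For example, `page_window(3, 25)` yields a skip of 50 and a take of 25, while out-of-range input falls back to the first page and the size cap.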
@router.get(
"/llm/creators",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_creators() -> dict[str, Any]:
"""List all LLM model creators.
Requires admin authentication.
"""
try:
import prisma.models
creators = await prisma.models.LlmModelCreator.prisma().find_many(
order={"name": "asc"}
)
logger.info(f"Retrieved {len(creators)} creators")
return {"creators": [_map_creator_response(c) for c in creators]}
except Exception as e:
logger.exception(f"Failed to list creators: {e}")
raise HTTPException(status_code=500, detail="Failed to list creators")
@router.post(
"/llm/creators",
status_code=status.HTTP_201_CREATED,
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_creator(
request: dict[str, Any],
) -> dict[str, Any]:
"""Create a new LLM model creator."""
try:
import prisma.models
creator = await prisma.models.LlmModelCreator.prisma().create(
data={
"name": request["name"],
"displayName": request["display_name"],
"description": request.get("description"),
"websiteUrl": request.get("website_url"),
"logoUrl": request.get("logo_url"),
"metadata": prisma.Json(request.get("metadata", {})),
}
)
logger.info(f"Created creator '{creator.name}' (id: {creator.id})")
return _map_creator_response(creator)
except Exception as e:
logger.exception(f"Failed to create creator: {e}")
raise HTTPException(status_code=500, detail="Failed to create creator")
@router.patch(
"/llm/creators/{name}",
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_creator(
name: str,
request: dict[str, Any],
) -> dict[str, Any]:
"""Update an existing LLM model creator."""
try:
import prisma.models
existing = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"name": name}
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Creator '{name}' not found"
)
data: dict[str, Any] = {}
if "display_name" in request:
data["displayName"] = request["display_name"]
if "description" in request:
data["description"] = request["description"]
if "website_url" in request:
data["websiteUrl"] = request["website_url"]
if "logo_url" in request:
data["logoUrl"] = request["logo_url"]
creator = await prisma.models.LlmModelCreator.prisma().update(
where={"id": existing.id},
data=data,
)
logger.info(f"Updated creator '{name}' (id: {creator.id})")
return _map_creator_response(creator)
except HTTPException:
raise
except Exception as e:
logger.exception(f"Failed to update creator: {e}")
raise HTTPException(status_code=500, detail="Failed to update creator")
@router.delete(
"/llm/creators/{name}",
status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_creator(
name: str,
) -> None:
"""Delete an LLM model creator."""
try:
import prisma.models
existing = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"name": name},
include={"Models": True},
)
if not existing:
raise HTTPException(
status_code=404, detail=f"Creator '{name}' not found"
)
if existing.Models and len(existing.Models) > 0:
raise HTTPException(
status_code=400,
detail=f"Cannot delete creator '{name}' — it has {len(existing.Models)} associated models",
)
await prisma.models.LlmModelCreator.prisma().delete(
where={"id": existing.id}
)
logger.info(f"Deleted creator '{name}' (id: {existing.id})")
except HTTPException:
raise
except Exception as e:
logger.exception(f"Failed to delete creator: {e}")
raise HTTPException(status_code=500, detail="Failed to delete creator")
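
The chain of `if "..." in request` checks in `update_creator` could be collapsed into a lookup table. The following is a minimal sketch, assuming a hypothetical `CREATOR_FIELD_MAP` constant and `build_creator_update` helper (neither exists in this PR). Like the handler above, it keeps explicit `None` values so a field can still be cleared, and it ignores keys that are not updatable.

```python
from typing import Any

# Hypothetical mapping from request keys to the Prisma column names
# used by update_creator above.
CREATOR_FIELD_MAP: dict[str, str] = {
    "display_name": "displayName",
    "description": "description",
    "website_url": "websiteUrl",
    "logo_url": "logoUrl",
}


def build_creator_update(request: dict[str, Any]) -> dict[str, Any]:
    """Return only the fields present in the request, renamed for Prisma."""
    return {
        column: request[key]
        for key, column in CREATOR_FIELD_MAP.items()
        if key in request
    }


# "name" is not in the map, so it is dropped; logo_url=None is kept.
update = build_creator_update(
    {"name": "openai", "display_name": "OpenAI", "logo_url": None}
)
print(update)
```

A table like this keeps the accepted field set in one place, which may be easier to audit than a growing list of `if` statements.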

@@ -0,0 +1,588 @@
"""Database write operations for LLM registry admin API."""
import json
import logging
from datetime import datetime, timezone
from typing import Any
import prisma
import prisma.models
from backend.data import llm_registry
from backend.data.db import transaction
logger = logging.getLogger(__name__)
def _build_provider_data(
name: str,
display_name: str,
description: str | None = None,
default_credential_provider: str | None = None,
default_credential_id: str | None = None,
default_credential_type: str | None = None,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build provider data dict for Prisma operations."""
return {
"name": name,
"displayName": display_name,
"description": description,
"defaultCredentialProvider": default_credential_provider,
"defaultCredentialId": default_credential_id,
"defaultCredentialType": default_credential_type,
"metadata": prisma.Json(metadata or {}),
}
def _build_model_data(
slug: str,
display_name: str,
provider_id: str,
context_window: int,
price_tier: int,
description: str | None = None,
creator_id: str | None = None,
max_output_tokens: int | None = None,
is_enabled: bool = True,
is_recommended: bool = False,
supports_tools: bool = False,
supports_json_output: bool = False,
supports_reasoning: bool = False,
supports_parallel_tool_calls: bool = False,
capabilities: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build model data dict for Prisma operations."""
data: dict[str, Any] = {
"slug": slug,
"displayName": display_name,
"description": description,
"Provider": {"connect": {"id": provider_id}},
"contextWindow": context_window,
"maxOutputTokens": max_output_tokens,
"priceTier": price_tier,
"isEnabled": is_enabled,
"isRecommended": is_recommended,
"supportsTools": supports_tools,
"supportsJsonOutput": supports_json_output,
"supportsReasoning": supports_reasoning,
"supportsParallelToolCalls": supports_parallel_tool_calls,
"capabilities": prisma.Json(capabilities or {}),
"metadata": prisma.Json(metadata or {}),
}
if creator_id:
data["Creator"] = {"connect": {"id": creator_id}}
return data
async def create_provider(
name: str,
display_name: str,
description: str | None = None,
default_credential_provider: str | None = None,
default_credential_id: str | None = None,
default_credential_type: str | None = None,
metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
"""Create a new LLM provider."""
data = _build_provider_data(
name=name,
display_name=display_name,
description=description,
default_credential_provider=default_credential_provider,
default_credential_id=default_credential_id,
default_credential_type=default_credential_type,
metadata=metadata,
)
provider = await prisma.models.LlmProvider.prisma().create(
data=data,
include={"Models": True},
)
if not provider:
raise ValueError("Failed to create provider")
return provider
async def update_provider(
provider_id: str,
display_name: str | None = None,
description: str | None = None,
default_credential_provider: str | None = None,
default_credential_id: str | None = None,
default_credential_type: str | None = None,
metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
"""Update an existing LLM provider."""
# Confirm the provider exists before attempting the update
provider = await prisma.models.LlmProvider.prisma().find_unique(
where={"id": provider_id}
)
if not provider:
raise ValueError(f"Provider with id '{provider_id}' not found")
# Build update data (only include fields that are provided)
data: dict[str, Any] = {}
if display_name is not None:
data["displayName"] = display_name
if description is not None:
data["description"] = description
if default_credential_provider is not None:
data["defaultCredentialProvider"] = default_credential_provider
if default_credential_id is not None:
data["defaultCredentialId"] = default_credential_id
if default_credential_type is not None:
data["defaultCredentialType"] = default_credential_type
if metadata is not None:
data["metadata"] = prisma.Json(metadata)
updated = await prisma.models.LlmProvider.prisma().update(
where={"id": provider_id},
data=data,
include={"Models": True},
)
if not updated:
raise ValueError("Failed to update provider")
return updated
async def delete_provider(provider_id: str) -> bool:
"""Delete an LLM provider.
A provider can only be deleted if it has no associated models.
"""
# Check if provider exists
provider = await prisma.models.LlmProvider.prisma().find_unique(
where={"id": provider_id},
include={"Models": True},
)
if not provider:
raise ValueError(f"Provider with id '{provider_id}' not found")
# Check if provider has any models
model_count = len(provider.Models) if provider.Models else 0
if model_count > 0:
raise ValueError(
f"Cannot delete provider '{provider.displayName}' because it has "
f"{model_count} model(s). Delete all models first."
)
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
return True
async def create_model(
slug: str,
display_name: str,
provider_id: str,
context_window: int,
price_tier: int,
description: str | None = None,
creator_id: str | None = None,
max_output_tokens: int | None = None,
is_enabled: bool = True,
is_recommended: bool = False,
supports_tools: bool = False,
supports_json_output: bool = False,
supports_reasoning: bool = False,
supports_parallel_tool_calls: bool = False,
capabilities: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
"""Create a new LLM model."""
data = _build_model_data(
slug=slug,
display_name=display_name,
provider_id=provider_id,
context_window=context_window,
price_tier=price_tier,
description=description,
creator_id=creator_id,
max_output_tokens=max_output_tokens,
is_enabled=is_enabled,
is_recommended=is_recommended,
supports_tools=supports_tools,
supports_json_output=supports_json_output,
supports_reasoning=supports_reasoning,
supports_parallel_tool_calls=supports_parallel_tool_calls,
capabilities=capabilities,
metadata=metadata,
)
model = await prisma.models.LlmModel.prisma().create(
data=data,
include={"Costs": True, "Creator": True, "Provider": True},
)
if not model:
raise ValueError("Failed to create model")
return model
async def update_model(
model_id: str,
display_name: str | None = None,
description: str | None = None,
creator_id: str | None = None,
context_window: int | None = None,
max_output_tokens: int | None = None,
price_tier: int | None = None,
is_enabled: bool | None = None,
is_recommended: bool | None = None,
supports_tools: bool | None = None,
supports_json_output: bool | None = None,
supports_reasoning: bool | None = None,
supports_parallel_tool_calls: bool | None = None,
capabilities: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
"""Update an existing LLM model.
When is_recommended=True, clears the flag on all other models first so
only one model can be recommended at a time.
"""
# Build update data (only include fields that are provided)
data: dict[str, Any] = {}
if display_name is not None:
data["displayName"] = display_name
if description is not None:
data["description"] = description
if context_window is not None:
data["contextWindow"] = context_window
if max_output_tokens is not None:
data["maxOutputTokens"] = max_output_tokens
if price_tier is not None:
data["priceTier"] = price_tier
if is_enabled is not None:
data["isEnabled"] = is_enabled
if is_recommended is not None:
data["isRecommended"] = is_recommended
if supports_tools is not None:
data["supportsTools"] = supports_tools
if supports_json_output is not None:
data["supportsJsonOutput"] = supports_json_output
if supports_reasoning is not None:
data["supportsReasoning"] = supports_reasoning
if supports_parallel_tool_calls is not None:
data["supportsParallelToolCalls"] = supports_parallel_tool_calls
if capabilities is not None:
data["capabilities"] = prisma.Json(capabilities)
if metadata is not None:
data["metadata"] = prisma.Json(metadata)
# None leaves the creator unchanged; an empty string clears it.
if creator_id is not None:
data["creatorId"] = creator_id if creator_id else None
async with transaction() as tx:
# Enforce single recommended model: unset all others first.
if is_recommended is True:
await tx.llmmodel.update_many(
where={"id": {"not": model_id}},
data={"isRecommended": False},
)
model = await tx.llmmodel.update(
where={"id": model_id},
data=data,
include={"Costs": True, "Creator": True, "Provider": True},
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
return model
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Get the usage count for a model: how many AgentNodes reference it."""
    count_result = await prisma.get_client().query_raw(
        """
        SELECT COUNT(*) as count
        FROM "AgentNode"
        WHERE "constantInput"::jsonb->>'model' = $1
        """,
        slug,
    )
    node_count = int(count_result[0]["count"]) if count_result else 0
    return {"model_slug": slug, "node_count": node_count}
async def toggle_model_with_migration(
model_id: str,
is_enabled: bool,
migrate_to_slug: str | None = None,
migration_reason: str | None = None,
custom_credit_cost: int | None = None,
) -> dict[str, Any]:
"""Toggle a model's enabled status, optionally migrating workflows when disabling."""
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
nodes_migrated = 0
migration_id: str | None = None
if not is_enabled and migrate_to_slug:
async with transaction() as tx:
replacement = await tx.llmmodel.find_unique(
where={"slug": migrate_to_slug}
)
if not replacement:
raise ValueError(
f"Replacement model '{migrate_to_slug}' not found"
)
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{migrate_to_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
node_ids_result = await tx.query_raw(
"""
SELECT id
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
)
nodes_migrated = len(migrated_node_ids)
if nodes_migrated > 0:
node_ids_json = json.dumps(migrated_node_ids)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
""",
migrate_to_slug,
node_ids_json,
)
await tx.llmmodel.update(
where={"id": model_id},
data={"isEnabled": is_enabled},
)
if nodes_migrated > 0:
migration_record = await tx.llmmodelmigration.create(
data={
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
)
migration_id = migration_record.id
else:
await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data={"isEnabled": is_enabled},
)
return {
"nodes_migrated": nodes_migrated,
"migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
"migration_id": migration_id,
}
async def delete_model(
model_id: str, replacement_model_slug: str | None = None
) -> dict[str, Any]:
"""Delete an LLM model, optionally migrating affected AgentNodes first.
If workflows are using this model and no replacement is given, raises ValueError.
If replacement is given, atomically migrates all affected nodes then deletes.
"""
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
deleted_slug = model.slug
deleted_display_name = model.displayName
async with transaction() as tx:
count_result = await tx.query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug,
)
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
if nodes_to_migrate > 0:
if not replacement_model_slug:
raise ValueError(
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
f"are using it. Please provide a replacement_model_slug to migrate them."
)
replacement = await tx.llmmodel.find_unique(
where={"slug": replacement_model_slug}
)
if not replacement:
raise ValueError(
f"Replacement model '{replacement_model_slug}' not found"
)
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{replacement_model_slug}' is disabled."
)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
await tx.llmmodel.delete(where={"id": model_id})
return {
"deleted_model_slug": deleted_slug,
"deleted_model_display_name": deleted_display_name,
"replacement_model_slug": replacement_model_slug,
"nodes_migrated": nodes_to_migrate,
}
async def list_migrations(
include_reverted: bool = False,
) -> list[dict[str, Any]]:
"""List model migrations."""
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
)
return [
{
"id": r.id,
"source_model_slug": r.sourceModelSlug,
"target_model_slug": r.targetModelSlug,
"reason": r.reason,
"node_count": r.nodeCount,
"custom_credit_cost": r.customCreditCost,
"is_reverted": r.isReverted,
"reverted_at": r.revertedAt.isoformat() if r.revertedAt else None,
"created_at": r.createdAt.isoformat(),
}
for r in records
]
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> dict[str, Any]:
"""Revert a model migration, restoring affected nodes to their original model."""
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
if not migration:
raise ValueError(f"Migration with id '{migration_id}' not found")
if migration.isReverted:
raise ValueError(
f"Migration '{migration_id}' has already been reverted"
)
source_model = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migration.sourceModelSlug}
)
if not source_model:
raise ValueError(
f"Source model '{migration.sourceModelSlug}' no longer exists."
)
migrated_node_ids: list[str] = (
migration.migratedNodeIds
if isinstance(migration.migratedNodeIds, list)
else json.loads(migration.migratedNodeIds) # type: ignore
)
if not migrated_node_ids:
raise ValueError("No nodes to revert in this migration")
source_model_re_enabled = False
async with transaction() as tx:
if not source_model.isEnabled and re_enable_source_model:
await tx.llmmodel.update(
where={"id": source_model.id},
data={"isEnabled": True},
)
source_model_re_enabled = True
node_ids_json = json.dumps(migrated_node_ids)
result = await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
AND "constantInput"::jsonb->>'model' = $3
""",
migration.sourceModelSlug,
node_ids_json,
migration.targetModelSlug,
)
nodes_reverted = result if isinstance(result, int) else 0
await tx.llmmodelmigration.update(
where={"id": migration_id},
data={
"isReverted": True,
"revertedAt": datetime.now(timezone.utc),
},
)
return {
"migration_id": migration_id,
"source_model_slug": migration.sourceModelSlug,
"target_model_slug": migration.targetModelSlug,
"nodes_reverted": nodes_reverted,
"nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
"source_model_re_enabled": source_model_re_enabled,
}
async def refresh_runtime_caches() -> None:
    """Invalidate the shared Redis cache, refresh this process, notify other workers."""
    from backend.data.llm_registry.notifications import (
        publish_registry_refresh_notification,
    )

    # Invalidate Redis so the next fetch hits the DB.
    llm_registry.clear_registry_cache()
    # Refresh this process (also repopulates Redis via @cached(shared_cache=True)).
    await llm_registry.refresh_llm_registry()
    # Tell other workers to reload their in-process cache from the fresh Redis data.
    await publish_registry_refresh_notification()
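
The migrate/revert bookkeeping in `toggle_model_with_migration` and `revert_migration` can be illustrated without a database. The sketch below is an in-memory stand-in, not the PR's SQL: `migrate_nodes` and `revert_nodes` are hypothetical helpers, and the `nodes` dict plays the role of `AgentNode.constantInput`. It shows why a revert only touches nodes that still point at the target model, and how `nodes_already_changed` follows from that.

```python
from typing import Any


def migrate_nodes(
    nodes: dict[str, dict[str, Any]], source: str, target: str
) -> list[str]:
    """Point every node using `source` at `target`; return the migrated ids."""
    migrated = [nid for nid, inp in nodes.items() if inp.get("model") == source]
    for nid in migrated:
        nodes[nid]["model"] = target
    return migrated


def revert_nodes(
    nodes: dict[str, dict[str, Any]],
    migrated_ids: list[str],
    source: str,
    target: str,
) -> dict[str, int]:
    """Restore recorded nodes to `source`, skipping any edited since migration."""
    reverted = 0
    for nid in migrated_ids:
        node = nodes.get(nid)
        # Only revert nodes that still use the migration target.
        if node is not None and node.get("model") == target:
            node["model"] = source
            reverted += 1
    return {
        "nodes_reverted": reverted,
        "nodes_already_changed": len(migrated_ids) - reverted,
    }


# Hypothetical data: two nodes on the old model, one unrelated node.
nodes = {
    "a": {"model": "old-model"},
    "b": {"model": "old-model"},
    "c": {"model": "other-model"},
}
ids = migrate_nodes(nodes, "old-model", "new-model")
nodes["b"]["model"] = "manual-choice"  # a user edits node b after the migration
summary = revert_nodes(nodes, ids, "old-model", "new-model")
print(summary)
```

Recording the migrated ids, rather than re-querying by slug at revert time, is what keeps nodes that were already on the target model before the migration from being swept up in the revert.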

@@ -0,0 +1,68 @@
"""Pydantic models for LLM registry public API."""
from __future__ import annotations

from typing import Any

import pydantic


class LlmModelCost(pydantic.BaseModel):
    """Cost configuration for an LLM model."""

    unit: str  # "RUN" or "TOKENS"
    credit_cost: int = pydantic.Field(ge=0)
    credential_provider: str
    credential_id: str | None = None
    credential_type: str | None = None
    currency: str | None = None
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)


class LlmModelCreator(pydantic.BaseModel):
    """Represents the organization that created/trained the model."""

    id: str
    name: str
    display_name: str
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None


class LlmModel(pydantic.BaseModel):
    """Public-facing LLM model information."""

    slug: str
    display_name: str
    description: str | None = None
    provider_name: str
    creator: LlmModelCreator | None = None
    context_window: int
    max_output_tokens: int | None = None
    price_tier: int  # 1=cheapest, 2=medium, 3=expensive
    is_enabled: bool = True
    is_recommended: bool = False
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    costs: list[LlmModelCost] = pydantic.Field(default_factory=list)


class LlmProvider(pydantic.BaseModel):
    """Provider with its enabled models."""

    name: str
    display_name: str
    models: list[LlmModel] = pydantic.Field(default_factory=list)


class LlmModelsResponse(pydantic.BaseModel):
    """Response for GET /llm/models."""

    models: list[LlmModel]
    total: int


class LlmProvidersResponse(pydantic.BaseModel):
    """Response for GET /llm/providers."""

    providers: list[LlmProvider]

@@ -0,0 +1,143 @@
"""Public read-only API for LLM registry."""
import autogpt_libs.auth
import fastapi
from backend.data.llm_registry import (
RegistryModelCreator,
get_all_models,
get_enabled_models,
)
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
def _map_creator(
creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
"""Convert registry creator to API model."""
if not creator:
return None
return llm_model.LlmModelCreator(
id=creator.id,
name=creator.name,
display_name=creator.display_name,
description=creator.description,
website_url=creator.website_url,
logo_url=creator.logo_url,
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
enabled_only: bool = fastapi.Query(
default=True, description="Only return enabled models"
),
):
"""
List all LLM models available to users.
Returns models from the in-memory registry cache.
Use enabled_only=true to filter to only enabled models (default).
"""
# Get models from in-memory registry
registry_models = get_enabled_models() if enabled_only else get_all_models()
# Map to API response models
models = [
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_enabled=model.is_enabled,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in registry_models
]
return llm_model.LlmModelsResponse(models=models, total=len(models))
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""
List all LLM providers with their enabled models.
Groups enabled models by provider from the in-memory registry.
"""
# Get all enabled models from the in-memory registry
registry_models = get_enabled_models()
# Group models by provider
provider_map: dict[str, list] = {}
for model in registry_models:
provider_key = model.metadata.provider
if provider_key not in provider_map:
provider_map[provider_key] = []
provider_map[provider_key].append(model)
# Build provider responses
providers = []
for provider_key, models in sorted(provider_map.items()):
# Use the first model's provider display name
display_name = models[0].provider_display_name if models else provider_key
providers.append(
llm_model.LlmProvider(
name=provider_key,
display_name=display_name,
models=[
llm_model.LlmModel(
slug=model.slug,
display_name=model.display_name,
description=model.description,
provider_name=model.provider_display_name,
creator=_map_creator(model.creator),
context_window=model.metadata.context_window,
max_output_tokens=model.metadata.max_output_tokens,
price_tier=model.metadata.price_tier,
is_enabled=model.is_enabled,
is_recommended=model.is_recommended,
capabilities=model.capabilities,
costs=[
llm_model.LlmModelCost(
unit=cost.unit,
credit_cost=cost.credit_cost,
credential_provider=cost.credential_provider,
credential_id=cost.credential_id,
credential_type=cost.credential_type,
currency=cost.currency,
metadata=cost.metadata,
)
for cost in model.costs
],
)
for model in sorted(models, key=lambda m: m.display_name)
],
)
)
return llm_model.LlmProvidersResponse(providers=providers)
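
The grouping step in `list_providers` can be seen in isolation with a small standalone sketch. `ModelStub` and `group_by_provider` are hypothetical names introduced only for illustration; the real handler works on registry model objects and builds `LlmProvider` responses instead of tuples. The ordering matches the handler: providers sorted by key, models sorted by display name.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class ModelStub:
    """Hypothetical stand-in for a registry model entry."""

    slug: str
    display_name: str
    provider: str


def group_by_provider(models: list[ModelStub]) -> list[tuple[str, list[str]]]:
    """Group model slugs by provider, sorted by provider then display name."""
    grouped: dict[str, list[ModelStub]] = defaultdict(list)
    for model in models:
        grouped[model.provider].append(model)
    return [
        (provider, [m.slug for m in sorted(items, key=lambda m: m.display_name)])
        for provider, items in sorted(grouped.items())
    ]


# Hypothetical sample data.
sample = [
    ModelStub("model-b", "Model B", "provider-two"),
    ModelStub("model-c", "Model C", "provider-one"),
    ModelStub("model-a", "Model A", "provider-two"),
]
print(group_by_provider(sample))
```

Using `defaultdict(list)` removes the explicit "create the list if the key is missing" branch; whether that is worth changing in the handler is a matter of taste.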

@@ -0,0 +1,14 @@
"""Override session-scoped fixtures from parent conftest.py so unit tests
in this directory can run without the full server stack."""

import pytest


@pytest.fixture(scope="session")
def server():
    yield None


@pytest.fixture(scope="session", autouse=True)
def graph_cleanup():
    yield

@@ -0,0 +1,150 @@
"""Helpers for OpenAI Responses API.
This module provides utilities for using OpenAI's Responses API, which is the
default for all OpenAI models supported by the platform.
"""

from typing import Any


def convert_tools_to_responses_format(tools: list[dict] | None) -> list[dict]:
    """Convert Chat Completions tool format to Responses API format.

    The Responses API uses internally-tagged polymorphism (flatter structure)
    and functions are strict by default.

    Chat Completions format:
        {"type": "function", "function": {"name": "...", "parameters": {...}}}

    Responses API format:
        {"type": "function", "name": "...", "parameters": {...}}

    Args:
        tools: List of tools in Chat Completions format

    Returns:
        List of tools in Responses API format
    """
    if not tools:
        return []

    converted = []
    for tool in tools:
        if tool.get("type") == "function":
            func = tool.get("function", {})
            name = func.get("name")
            if not name:
                raise ValueError(
                    f"Function tool is missing required 'name' field: {tool}"
                )
            entry: dict[str, Any] = {
                "type": "function",
                "name": name,
                # Note: strict=True is default in Responses API
            }
            if func.get("description") is not None:
                entry["description"] = func["description"]
            if func.get("parameters") is not None:
                entry["parameters"] = func["parameters"]
            converted.append(entry)
        else:
            # Pass through non-function tools as-is
            converted.append(tool)

    return converted


def extract_responses_tool_calls(response: Any) -> list[dict] | None:
    """Extract tool calls from Responses API response.

    The Responses API returns tool calls as separate items in the output array
    with type="function_call".

    Args:
        response: The Responses API response object

    Returns:
        List of tool calls in a normalized format, or None if no tool calls
    """
    tool_calls = []
    for item in response.output:
        if getattr(item, "type", None) == "function_call":
            tool_calls.append(
                {
                    "id": item.call_id,
                    "type": "function",
                    "function": {
                        "name": item.name,
                        "arguments": item.arguments,
                    },
                }
            )
    return tool_calls if tool_calls else None


def extract_responses_usage(response: Any) -> tuple[int, int]:
    """Extract token usage from Responses API response.

    The Responses API uses input_tokens/output_tokens (not prompt_tokens/completion_tokens).

    Args:
        response: The Responses API response object

    Returns:
        Tuple of (input_tokens, output_tokens)
    """
    if not getattr(response, "usage", None):
        return 0, 0
    return (
        getattr(response.usage, "input_tokens", 0),
        getattr(response.usage, "output_tokens", 0),
    )


def extract_responses_content(response: Any) -> str:
    """Extract text content from Responses API response.

    Args:
        response: The Responses API response object

    Returns:
        The text content from the response, or empty string if none
    """
    # The SDK provides a helper property
    if hasattr(response, "output_text"):
        return response.output_text or ""

    # Fallback: manually extract from output items
    for item in response.output:
        if getattr(item, "type", None) == "message":
            for content in getattr(item, "content", []):
                if getattr(content, "type", None) == "output_text":
                    return getattr(content, "text", "")
    return ""


def extract_responses_reasoning(response: Any) -> str | None:
    """Extract reasoning content from Responses API response.

    Reasoning models return their reasoning process in the response,
    which can be useful for debugging or display.

    Args:
        response: The Responses API response object

    Returns:
        The reasoning text, or None if not present
    """
    for item in response.output:
        if getattr(item, "type", None) == "reasoning":
            # Reasoning items may have summary or content
            summary = getattr(item, "summary", [])
            if summary:
                # Join summary items if present
                texts = []
                for s in summary:
                    if hasattr(s, "text"):
                        texts.append(s.text)
                if texts:
                    return "\n".join(texts)
    return None
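
To make the normalized tool-call shape concrete, here is a standalone sketch that mirrors the `function_call` filtering above. It is not the module's own function: `normalize_tool_calls` is a hypothetical name, it returns an empty list where `extract_responses_tool_calls` returns `None`, and `SimpleNamespace` objects stand in for the SDK's response items.

```python
import json
from types import SimpleNamespace
from typing import Any


def normalize_tool_calls(response: Any) -> list[dict[str, Any]]:
    """Illustrative mirror of the function_call extraction; returns [] if none."""
    return [
        {
            "id": item.call_id,
            "type": "function",
            "function": {"name": item.name, "arguments": item.arguments},
        }
        for item in response.output
        if getattr(item, "type", None) == "function_call"
    ]


# Hypothetical stand-in for a Responses API result with one message item
# and one function_call item.
fake_response = SimpleNamespace(
    output=[
        SimpleNamespace(type="message", content=[]),
        SimpleNamespace(
            type="function_call",
            call_id="call_1",
            name="get_weather",
            arguments='{"location": "NYC"}',
        ),
    ]
)

calls = normalize_tool_calls(fake_response)
# Arguments arrive as a JSON string, so callers decode them before use.
arguments = json.loads(calls[0]["function"]["arguments"])
print(calls[0]["function"]["name"], arguments)
```

The output shape matches the Chat Completions tool-call layout, so downstream code written against that layout should not need to know which API produced the call.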

@@ -0,0 +1,312 @@
"""Tests for OpenAI Responses API helpers."""
from unittest.mock import MagicMock
from backend.util.openai_responses import (
convert_tools_to_responses_format,
extract_responses_content,
extract_responses_reasoning,
extract_responses_tool_calls,
extract_responses_usage,
)
class TestConvertToolsToResponsesFormat:
"""Tests for the convert_tools_to_responses_format function."""
def test_empty_tools_returns_empty_list(self):
"""Empty or None tools should return empty list."""
assert convert_tools_to_responses_format(None) == []
assert convert_tools_to_responses_format([]) == []
def test_converts_function_tool_format(self):
"""Should convert Chat Completions function format to Responses format."""
chat_completions_tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
},
},
}
]
result = convert_tools_to_responses_format(chat_completions_tools)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["name"] == "get_weather"
assert result[0]["description"] == "Get the weather in a location"
assert result[0]["parameters"] == {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
}
# Should not have nested "function" key
assert "function" not in result[0]
def test_handles_multiple_tools(self):
"""Should handle multiple tools."""
chat_completions_tools = [
{
"type": "function",
"function": {
"name": "tool_1",
"description": "First tool",
"parameters": {"type": "object", "properties": {}},
},
},
{
"type": "function",
"function": {
"name": "tool_2",
"description": "Second tool",
"parameters": {"type": "object", "properties": {}},
},
},
]
result = convert_tools_to_responses_format(chat_completions_tools)
assert len(result) == 2
assert result[0]["name"] == "tool_1"
assert result[1]["name"] == "tool_2"
def test_passes_through_non_function_tools(self):
"""Non-function tools should be passed through as-is."""
tools = [{"type": "web_search", "config": {"enabled": True}}]
result = convert_tools_to_responses_format(tools)
assert result == tools
def test_omits_none_description_and_parameters(self):
"""Should omit description and parameters when they are None."""
tools = [
{
"type": "function",
"function": {
"name": "simple_tool",
},
}
]
result = convert_tools_to_responses_format(tools)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["name"] == "simple_tool"
assert "description" not in result[0]
assert "parameters" not in result[0]
def test_raises_on_missing_name(self):
"""Should raise ValueError when function tool has no name."""
import pytest
tools = [{"type": "function", "function": {}}]
with pytest.raises(ValueError, match="missing required 'name' field"):
convert_tools_to_responses_format(tools)
class TestExtractResponsesToolCalls:
"""Tests for the extract_responses_tool_calls function."""
def test_extracts_function_call_items(self):
"""Should extract function_call items from response output."""
item = MagicMock()
item.type = "function_call"
item.call_id = "call_123"
item.name = "get_weather"
item.arguments = '{"location": "NYC"}'
response = MagicMock()
response.output = [item]
result = extract_responses_tool_calls(response)
assert result == [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "NYC"}',
},
}
]
def test_returns_none_when_no_tool_calls(self):
"""Should return None when no function_call items exist."""
message_item = MagicMock()
message_item.type = "message"
response = MagicMock()
response.output = [message_item]
assert extract_responses_tool_calls(response) is None
def test_returns_none_for_empty_output(self):
"""Should return None when output is empty."""
response = MagicMock()
response.output = []
assert extract_responses_tool_calls(response) is None
def test_extracts_multiple_tool_calls(self):
"""Should extract multiple function_call items."""
item1 = MagicMock()
item1.type = "function_call"
item1.call_id = "call_1"
item1.name = "tool_a"
item1.arguments = "{}"
item2 = MagicMock()
item2.type = "function_call"
item2.call_id = "call_2"
item2.name = "tool_b"
item2.arguments = '{"x": 1}'
response = MagicMock()
response.output = [item1, item2]
result = extract_responses_tool_calls(response)
assert result is not None
assert len(result) == 2
assert result[0]["function"]["name"] == "tool_a"
assert result[1]["function"]["name"] == "tool_b"
class TestExtractResponsesUsage:
"""Tests for the extract_responses_usage function."""
def test_extracts_token_counts(self):
"""Should extract input_tokens and output_tokens."""
response = MagicMock()
response.usage.input_tokens = 42
response.usage.output_tokens = 17
result = extract_responses_usage(response)
assert result == (42, 17)
def test_returns_zeros_when_usage_is_none(self):
"""Should return (0, 0) when usage is None."""
response = MagicMock()
response.usage = None
result = extract_responses_usage(response)
assert result == (0, 0)
class TestExtractResponsesContent:
"""Tests for the extract_responses_content function."""
def test_extracts_from_output_text(self):
"""Should use output_text property when available."""
response = MagicMock()
response.output_text = "Hello world"
assert extract_responses_content(response) == "Hello world"
def test_returns_empty_string_when_output_text_is_none(self):
"""Should return empty string when output_text is None."""
response = MagicMock()
response.output_text = None
response.output = []
assert extract_responses_content(response) == ""
def test_fallback_to_output_items(self):
"""Should fall back to extracting from output items."""
text_content = MagicMock()
text_content.type = "output_text"
text_content.text = "Fallback content"
message_item = MagicMock()
message_item.type = "message"
message_item.content = [text_content]
response = MagicMock(spec=[]) # no output_text attribute
response.output = [message_item]
assert extract_responses_content(response) == "Fallback content"
def test_returns_empty_string_for_empty_output(self):
"""Should return empty string when no content found."""
response = MagicMock(spec=[]) # no output_text attribute
response.output = []
assert extract_responses_content(response) == ""
class TestExtractResponsesReasoning:
"""Tests for the extract_responses_reasoning function."""
def test_extracts_reasoning_summary(self):
"""Should extract reasoning text from summary items."""
summary_item = MagicMock()
summary_item.text = "Step 1: Think about it"
reasoning_item = MagicMock()
reasoning_item.type = "reasoning"
reasoning_item.summary = [summary_item]
response = MagicMock()
response.output = [reasoning_item]
assert extract_responses_reasoning(response) == "Step 1: Think about it"
def test_joins_multiple_summary_items(self):
"""Should join multiple summary text items with newlines."""
s1 = MagicMock()
s1.text = "First thought"
s2 = MagicMock()
s2.text = "Second thought"
reasoning_item = MagicMock()
reasoning_item.type = "reasoning"
reasoning_item.summary = [s1, s2]
response = MagicMock()
response.output = [reasoning_item]
assert extract_responses_reasoning(response) == "First thought\nSecond thought"
def test_returns_none_when_no_reasoning(self):
"""Should return None when no reasoning items exist."""
message_item = MagicMock()
message_item.type = "message"
response = MagicMock()
response.output = [message_item]
assert extract_responses_reasoning(response) is None
def test_returns_none_for_empty_output(self):
"""Should return None when output is empty."""
response = MagicMock()
response.output = []
assert extract_responses_reasoning(response) is None
def test_returns_none_when_summary_is_empty(self):
"""Should return None when reasoning item has empty summary."""
reasoning_item = MagicMock()
reasoning_item.type = "reasoning"
reasoning_item.summary = []
response = MagicMock()
response.output = [reasoning_item]
assert extract_responses_reasoning(response) is None
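
The extractor tests above imply a small contract: usage defaults to `(0, 0)`, content falls back from `output_text` to the `message` items, and reasoning summaries are joined with newlines. A hedged sketch of that contract follows; the `*_sketch` names are hypothetical, the real helpers may differ, and joining fallback text chunks with an empty string is an assumption the tests do not cover.

```python
from types import SimpleNamespace
from typing import Any, Optional


def usage_sketch(response: Any) -> tuple[int, int]:
    """Return (input_tokens, output_tokens), or (0, 0) when usage is missing."""
    usage = getattr(response, "usage", None)
    if usage is None:
        return (0, 0)
    return (usage.input_tokens, usage.output_tokens)


def content_sketch(response: Any) -> str:
    """Prefer the output_text convenience property, else scan message items."""
    text = getattr(response, "output_text", None)
    if text:
        return text
    chunks = [
        block.text
        for item in getattr(response, "output", [])
        if item.type == "message"
        for block in item.content
        if block.type == "output_text"
    ]
    return "".join(chunks)  # assumption: chunks are concatenated directly


def reasoning_sketch(response: Any) -> Optional[str]:
    """Join all reasoning summary texts with newlines; None if there are none."""
    parts = [
        summary.text
        for item in response.output
        if item.type == "reasoning"
        for summary in item.summary
    ]
    return "\n".join(parts) if parts else None


# Stand-in response object built from SimpleNamespace instead of MagicMock.
response = SimpleNamespace(
    usage=SimpleNamespace(input_tokens=42, output_tokens=17),
    output=[
        SimpleNamespace(
            type="reasoning",
            summary=[SimpleNamespace(text="First thought"),
                     SimpleNamespace(text="Second thought")],
        ),
        SimpleNamespace(
            type="message",
            content=[SimpleNamespace(type="output_text", text="Hello world")],
        ),
    ],
)
```

With this stand-in, which has no `output_text` attribute, `content_sketch` takes the fallback path through the `message` item.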


@@ -36,16 +36,34 @@ def _msg_tokens(msg: dict, enc) -> int:
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
is present, plus the tokenised content length.
For tool calls, we need to count tokens in tool_calls and content fields.
Supports Chat Completions, Anthropic, and Responses API formats.
"""
WRAPPER = 3 + (1 if "name" in msg else 0)
# Responses API: function_call items carry name, arguments, and call_id
if msg.get("type") == "function_call":
return (
WRAPPER
+ _tok_len(msg.get("name", ""), enc)
+ _tok_len(msg.get("arguments", ""), enc)
+ _tok_len(msg.get("call_id", ""), enc)
)
# Responses API: function_call_output items carry output and call_id
if msg.get("type") == "function_call_output":
return (
WRAPPER
+ _tok_len(msg.get("output", ""), enc)
+ _tok_len(msg.get("call_id", ""), enc)
)
# Count content tokens
content_tokens = _tok_len(msg.get("content") or "", enc)
# Count tool call tokens for both OpenAI and Anthropic formats
tool_call_tokens = 0
# OpenAI format: tool_calls array at message level
# OpenAI Chat Completions format: tool_calls array at message level
if "tool_calls" in msg and isinstance(msg["tool_calls"], list):
for tool_call in msg["tool_calls"]:
# Count the tool call structure tokens
@@ -85,6 +103,10 @@ def _msg_tokens(msg: dict, enc) -> int:
def _is_tool_message(msg: dict) -> bool:
"""Check if a message contains tool calls or results that should be protected."""
# Responses API: standalone function_call / function_call_output items
if msg.get("type") in ("function_call", "function_call_output"):
return True
content = msg.get("content")
# Check for Anthropic-style tool messages
@@ -94,7 +116,7 @@ def _is_tool_message(msg: dict) -> bool:
):
return True
# Check for OpenAI-style tool calls in the message
# Check for OpenAI Chat Completions-style tool calls in the message
if "tool_calls" in msg or msg.get("role") == "tool":
return True
@@ -113,11 +135,18 @@ def _is_objective_message(msg: dict) -> bool:
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
"""
Carefully truncate tool message content while preserving tool structure.
Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
Handles Anthropic, Chat Completions, and Responses API tool messages.
"""
# Responses API: function_call_output has "output" field
if msg.get("type") == "function_call_output":
output = msg.get("output", "")
if isinstance(output, str) and _tok_len(output, enc) > max_tokens:
msg["output"] = _truncate_middle_tokens(output, enc, max_tokens)
return
content = msg.get("content")
# OpenAI-style tool message: role="tool" with string content
# OpenAI Chat Completions tool message: role="tool" with string content
if msg.get("role") == "tool" and isinstance(content, str):
if _tok_len(content, enc) > max_tokens:
msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
@@ -251,18 +280,26 @@ def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs from an assistant message.
Supports both formats:
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
Supports all formats:
- OpenAI Chat Completions: {"role": "assistant", "tool_calls": [{"id": "..."}]}
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
- OpenAI Responses API: {"type": "function_call", "call_id": "..."}
Returns:
Set of tool_call IDs found in the message.
"""
ids: set[str] = set()
# Responses API: standalone function_call item
if msg.get("type") == "function_call":
if call_id := msg.get("call_id"):
ids.add(call_id)
return ids
if msg.get("role") != "assistant":
return ids
# OpenAI format: tool_calls array
# OpenAI Chat Completions format: tool_calls array
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
@@ -285,16 +322,23 @@ def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs that this message is responding to.
Supports both formats:
- OpenAI: {"role": "tool", "tool_call_id": "..."}
Supports all formats:
- OpenAI Chat Completions: {"role": "tool", "tool_call_id": "..."}
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
- OpenAI Responses API: {"type": "function_call_output", "call_id": "..."}
Returns:
Set of tool_call IDs this message responds to.
"""
ids: set[str] = set()
# OpenAI format: role=tool with tool_call_id
# Responses API: standalone function_call_output item
if msg.get("type") == "function_call_output":
if call_id := msg.get("call_id"):
ids.add(call_id)
return ids
# OpenAI Chat Completions format: role=tool with tool_call_id
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id:
@@ -313,8 +357,11 @@ def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
def _is_tool_response_message(msg: dict) -> bool:
"""Check if message is a tool response (OpenAI or Anthropic format)."""
# OpenAI format
"""Check if message is a tool response (Chat Completions, Anthropic, or Responses API)."""
# Responses API format
if msg.get("type") == "function_call_output":
return True
# OpenAI Chat Completions format
if msg.get("role") == "tool":
return True
# Anthropic format
@@ -332,13 +379,20 @@ def _remove_orphan_tool_responses(
"""
Remove tool response messages/blocks that reference orphan tool_call IDs.
Supports both OpenAI and Anthropic formats.
Supports OpenAI Chat Completions, Anthropic, and Responses API formats.
For Anthropic messages with mixed valid/orphan tool_result blocks,
filters out only the orphan blocks instead of dropping the entire message.
"""
result = []
for msg in messages:
# OpenAI format: role=tool - drop entire message if orphan
# Responses API: function_call_output - drop if orphan
if msg.get("type") == "function_call_output":
if msg.get("call_id") in orphan_ids:
continue
result.append(msg)
continue
# OpenAI Chat Completions: role=tool - drop entire message if orphan
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id in orphan_ids:
@@ -524,6 +578,18 @@ async def _summarize_messages_llm(
"""Summarize messages using an LLM."""
conversation = []
for msg in messages:
# Responses API: function_call items
if msg.get("type") == "function_call":
name = msg.get("name", "unknown_tool")
args = msg.get("arguments", "")
conversation.append(f"TOOL CALL ({name}): {args}")
continue
# Responses API: function_call_output items
if msg.get("type") == "function_call_output":
output = msg.get("output", "")
conversation.append(f"TOOL OUTPUT: {output}")
continue
role = msg.get("role", "")
content = msg.get("content", "")
if content and role in ("user", "assistant", "tool"):
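
The hunks above all follow one pattern: check the Responses API `type` first, then fall through to the role-based Chat Completions and Anthropic shapes. A compact sketch of how the three ID conventions line up is given below. The helper names are hypothetical, and the orphan rule used here (a response ID with no call anywhere in the list) is an assumption; the real validator may be stricter about ordering.

```python
def call_ids(msg: dict) -> set[str]:
    """IDs of tool calls issued by this message or item."""
    if msg.get("type") == "function_call":  # Responses API item, no role
        return {msg["call_id"]} if msg.get("call_id") else set()
    if msg.get("role") != "assistant":
        return set()
    # Chat Completions: tool_calls array on the assistant message
    ids = {tc["id"] for tc in msg.get("tool_calls") or [] if tc.get("id")}
    content = msg.get("content")
    if isinstance(content, list):  # Anthropic: tool_use content blocks
        ids |= {
            b["id"] for b in content
            if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
        }
    return ids


def response_ids(msg: dict) -> set[str]:
    """IDs of the tool calls this message or item answers."""
    if msg.get("type") == "function_call_output":  # Responses API item
        return {msg["call_id"]} if msg.get("call_id") else set()
    if msg.get("role") == "tool":  # Chat Completions
        return {msg["tool_call_id"]} if msg.get("tool_call_id") else set()
    content = msg.get("content")
    if msg.get("role") == "user" and isinstance(content, list):  # Anthropic
        return {
            b["tool_use_id"] for b in content
            if isinstance(b, dict) and b.get("type") == "tool_result"
            and b.get("tool_use_id")
        }
    return set()


def orphan_response_ids(messages: list[dict]) -> set[str]:
    """Response IDs that have no matching call anywhere in the conversation."""
    calls: set[str] = set()
    responses: set[str] = set()
    for msg in messages:
        calls |= call_ids(msg)
        responses |= response_ids(msg)
    return responses - calls
```

Checking `type` before `role` matters because Responses API items carry no top-level `role`, so a role-first check would return early and miss them.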


@@ -0,0 +1,603 @@
"""Tests for prompt.py compatibility with the OpenAI Responses API.
The Responses API uses a different conversation format:
- Tool calls are standalone items with ``type: "function_call"`` and ``call_id``
- Tool results are items with ``type: "function_call_output"`` and ``call_id``
- These items do NOT have ``role`` at the top level
These tests validate that prompt utilities correctly handle Responses API items
alongside Chat Completions and Anthropic formats.
"""
import pytest
from tiktoken import encoding_for_model
from backend.util.prompt import (
_ensure_tool_pairs_intact,
_extract_tool_call_ids_from_message,
_extract_tool_response_ids_from_message,
_is_tool_message,
_is_tool_response_message,
_msg_tokens,
_remove_orphan_tool_responses,
_truncate_tool_message_content,
compress_context,
validate_and_remove_orphan_tool_responses,
)
# ── Fixtures ──────────────────────────────────────────────────────────────
@pytest.fixture
def enc():
return encoding_for_model("gpt-4o")
# ── Sample items ──────────────────────────────────────────────────────────
FUNCTION_CALL_ITEM = {
"type": "function_call",
"id": "fc_abc",
"call_id": "call_abc",
"name": "search_tool",
"arguments": '{"query": "python asyncio tutorial"}',
"status": "completed",
}
FUNCTION_CALL_OUTPUT_ITEM = {
"type": "function_call_output",
"call_id": "call_abc",
"output": '{"results": ["result1", "result2", "result3"]}',
}
# ═══════════════════════════════════════════════════════════════════════════
# _msg_tokens
# ═══════════════════════════════════════════════════════════════════════════
class TestMsgTokensResponsesApi:
"""_msg_tokens should count tokens in function_call / function_call_output
items, not just role-based messages."""
def test_chat_completions_tool_call_counted(self, enc):
"""Baseline: Chat Completions tool_calls are counted correctly."""
msg = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_abc",
"type": "function",
"function": {
"name": "search_tool",
"arguments": '{"query": "python asyncio tutorial"}',
},
}
],
}
tokens = _msg_tokens(msg, enc)
assert tokens > 10 # Should count the tool call content
def test_chat_completions_tool_response_counted(self, enc):
"""Baseline: Chat Completions tool responses are counted correctly."""
msg = {
"role": "tool",
"tool_call_id": "call_abc",
"content": '{"results": ["result1", "result2"]}',
}
tokens = _msg_tokens(msg, enc)
assert tokens > 5
def test_function_call_minimal_fields(self, enc):
"""function_call with missing optional fields still counts."""
msg = {"type": "function_call"}
tokens = _msg_tokens(msg, enc)
assert tokens >= 3 # At least the wrapper
def test_function_call_output_minimal_fields(self, enc):
"""function_call_output with missing output field still counts."""
msg = {"type": "function_call_output"}
tokens = _msg_tokens(msg, enc)
assert tokens >= 3
def test_function_call_arguments_counted(self, enc):
"""function_call items have 'arguments' not 'content' — tokens must
include the arguments string and the function name."""
tokens = _msg_tokens(FUNCTION_CALL_ITEM, enc)
# Must count at least the arguments and name tokens
name_tokens = len(enc.encode(FUNCTION_CALL_ITEM["name"]))
args_tokens = len(enc.encode(FUNCTION_CALL_ITEM["arguments"]))
assert tokens >= name_tokens + args_tokens
def test_function_call_output_content_counted(self, enc):
"""function_call_output items have 'output' not 'content' — tokens must
include the output string."""
tokens = _msg_tokens(FUNCTION_CALL_OUTPUT_ITEM, enc)
output_tokens = len(enc.encode(FUNCTION_CALL_OUTPUT_ITEM["output"]))
assert tokens >= output_tokens
# ═══════════════════════════════════════════════════════════════════════════
# _is_tool_message
# ═══════════════════════════════════════════════════════════════════════════
class TestIsToolMessageResponsesApi:
"""_is_tool_message should recognise Responses API items as tool messages
so they are protected from deletion during compaction."""
def test_chat_completions_tool_call_detected(self):
"""Baseline: Chat Completions tool_calls are detected."""
msg = {
"role": "assistant",
"tool_calls": [{"id": "call_1", "type": "function"}],
}
assert _is_tool_message(msg) is True
def test_chat_completions_tool_response_detected(self):
"""Baseline: Chat Completions role=tool is detected."""
msg = {"role": "tool", "tool_call_id": "call_1", "content": "result"}
assert _is_tool_message(msg) is True
def test_anthropic_tool_use_detected(self):
"""Baseline: Anthropic tool_use is detected."""
msg = {
"role": "assistant",
"content": [
{"type": "tool_use", "id": "toolu_1", "name": "t", "input": {}}
],
}
assert _is_tool_message(msg) is True
def test_anthropic_tool_result_detected(self):
"""Baseline: Anthropic tool_result is detected."""
msg = {
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "toolu_1", "content": "ok"}
],
}
assert _is_tool_message(msg) is True
def test_function_call_detected(self):
"""type=function_call should be recognised as a tool message."""
assert _is_tool_message(FUNCTION_CALL_ITEM) is True
def test_function_call_output_detected(self):
"""type=function_call_output should be recognised as a tool message."""
assert _is_tool_message(FUNCTION_CALL_OUTPUT_ITEM) is True
def test_regular_user_message_not_tool(self):
"""Plain user message → not a tool message."""
assert _is_tool_message({"role": "user", "content": "hello"}) is False
def test_regular_assistant_message_not_tool(self):
"""Plain assistant message without tool_calls → not a tool message."""
assert _is_tool_message({"role": "assistant", "content": "hi"}) is False
# ═══════════════════════════════════════════════════════════════════════════
# _extract_tool_call_ids_from_message
# ═══════════════════════════════════════════════════════════════════════════
class TestExtractToolCallIdsResponsesApi:
"""_extract_tool_call_ids_from_message should extract call_ids from
Responses API function_call items."""
def test_chat_completions_extracted(self):
"""Baseline: Chat Completions tool_calls IDs are extracted."""
msg = {
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function"},
{"id": "call_2", "type": "function"},
],
}
assert _extract_tool_call_ids_from_message(msg) == {"call_1", "call_2"}
def test_anthropic_extracted(self):
"""Baseline: Anthropic tool_use IDs are extracted."""
msg = {
"role": "assistant",
"content": [{"type": "tool_use", "id": "toolu_1"}],
}
assert _extract_tool_call_ids_from_message(msg) == {"toolu_1"}
def test_function_call_extracted(self):
"""type=function_call with call_id should be extracted."""
assert _extract_tool_call_ids_from_message(FUNCTION_CALL_ITEM) == {"call_abc"}
def test_function_call_missing_call_id(self):
"""function_call without call_id → empty set."""
msg = {"type": "function_call", "name": "tool"}
assert _extract_tool_call_ids_from_message(msg) == set()
def test_non_assistant_non_function_call(self):
"""Messages with neither role=assistant nor type=function_call → empty."""
msg = {"role": "user", "content": "hello"}
assert _extract_tool_call_ids_from_message(msg) == set()
# ═══════════════════════════════════════════════════════════════════════════
# _extract_tool_response_ids_from_message
# ═══════════════════════════════════════════════════════════════════════════
class TestExtractToolResponseIdsResponsesApi:
"""_extract_tool_response_ids_from_message should extract call_ids from
Responses API function_call_output items."""
def test_chat_completions_extracted(self):
"""Baseline: Chat Completions tool_call_id is extracted."""
msg = {"role": "tool", "tool_call_id": "call_1", "content": "result"}
assert _extract_tool_response_ids_from_message(msg) == {"call_1"}
def test_anthropic_extracted(self):
"""Baseline: Anthropic tool_use_id is extracted."""
msg = {
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "toolu_1", "content": "ok"}
],
}
assert _extract_tool_response_ids_from_message(msg) == {"toolu_1"}
def test_function_call_output_extracted(self):
"""type=function_call_output with call_id should be extracted."""
assert _extract_tool_response_ids_from_message(FUNCTION_CALL_OUTPUT_ITEM) == {
"call_abc"
}
def test_function_call_output_missing_call_id(self):
"""function_call_output without call_id → empty set."""
msg = {"type": "function_call_output", "output": "result"}
assert _extract_tool_response_ids_from_message(msg) == set()
def test_non_tool_non_function_call_output(self):
"""Regular user message → empty set."""
msg = {"role": "user", "content": "hello"}
assert _extract_tool_response_ids_from_message(msg) == set()
# ═══════════════════════════════════════════════════════════════════════════
# _is_tool_response_message
# ═══════════════════════════════════════════════════════════════════════════
class TestIsToolResponseMessageResponsesApi:
def test_chat_completions_detected(self):
msg = {"role": "tool", "tool_call_id": "call_1", "content": "r"}
assert _is_tool_response_message(msg) is True
def test_anthropic_detected(self):
msg = {
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "toolu_1", "content": "ok"}
],
}
assert _is_tool_response_message(msg) is True
def test_function_call_output_detected(self):
"""type=function_call_output should be recognised as a tool response."""
assert _is_tool_response_message(FUNCTION_CALL_OUTPUT_ITEM) is True
def test_function_call_is_not_response(self):
"""function_call is a tool REQUEST, not a response."""
assert _is_tool_response_message(FUNCTION_CALL_ITEM) is False
def test_regular_message_not_response(self):
"""Plain message → not a tool response."""
assert _is_tool_response_message({"role": "user", "content": "hi"}) is False
# ═══════════════════════════════════════════════════════════════════════════
# _truncate_tool_message_content
# ═══════════════════════════════════════════════════════════════════════════
class TestTruncateToolMessageContentResponsesApi:
def test_chat_completions_truncated(self, enc):
"""Baseline: role=tool content is truncated."""
msg = {"role": "tool", "tool_call_id": "call_1", "content": "x" * 10000}
_truncate_tool_message_content(msg, enc, max_tokens=50)
assert len(enc.encode(msg["content"])) <= 55 # ~50 with rounding
def test_function_call_output_truncated(self, enc):
"""function_call_output 'output' field should be truncated."""
msg = {
"type": "function_call_output",
"call_id": "call_1",
"output": "x" * 10000,
}
_truncate_tool_message_content(msg, enc, max_tokens=50)
assert len(enc.encode(msg["output"])) <= 55
def test_function_call_output_short_not_truncated(self, enc):
"""Short function_call_output output is left unchanged."""
msg = {
"type": "function_call_output",
"call_id": "call_1",
"output": "short",
}
_truncate_tool_message_content(msg, enc, max_tokens=1000)
assert msg["output"] == "short"
def test_function_call_not_truncated(self, enc):
"""function_call items (requests) should not be truncated."""
msg = dict(FUNCTION_CALL_ITEM) # copy
original_args = msg["arguments"]
_truncate_tool_message_content(msg, enc, max_tokens=5)
assert msg["arguments"] == original_args # unchanged
# ═══════════════════════════════════════════════════════════════════════════
# _remove_orphan_tool_responses
# ═══════════════════════════════════════════════════════════════════════════
class TestRemoveOrphanToolResponsesResponsesApi:
def test_chat_completions_orphan_removed(self):
"""Baseline: orphan role=tool messages are removed."""
messages = [
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
{"role": "user", "content": "Hello"},
]
result = _remove_orphan_tool_responses(messages, {"call_orphan"})
assert len(result) == 1
assert result[0]["role"] == "user"
def test_function_call_output_orphan_removed(self):
"""Orphan function_call_output items should be removed."""
messages = [
{
"type": "function_call_output",
"call_id": "call_orphan",
"output": "result",
},
{"role": "user", "content": "Hello"},
]
result = _remove_orphan_tool_responses(messages, {"call_orphan"})
assert len(result) == 1
assert result[0]["role"] == "user"
def test_function_call_output_non_orphan_kept(self):
"""Non-orphan function_call_output items should be kept."""
messages = [
{
"type": "function_call_output",
"call_id": "call_valid",
"output": "result",
},
{"role": "user", "content": "Hello"},
]
result = _remove_orphan_tool_responses(messages, {"call_other"})
assert len(result) == 2
# ═══════════════════════════════════════════════════════════════════════════
# validate_and_remove_orphan_tool_responses
# ═══════════════════════════════════════════════════════════════════════════
class TestValidateOrphansResponsesApi:
def test_chat_completions_paired_kept(self):
"""Baseline: matched Chat Completions pairs are kept."""
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_1", "type": "function"}],
},
{"role": "tool", "tool_call_id": "call_1", "content": "done"},
]
result = validate_and_remove_orphan_tool_responses(messages, log_warning=False)
assert len(result) == 2
def test_responses_api_paired_kept(self):
"""Matched Responses API pairs are kept because the validator
properly recognizes function_call and function_call_output items."""
messages = [
{"role": "user", "content": "Do something."},
FUNCTION_CALL_ITEM,
FUNCTION_CALL_OUTPUT_ITEM,
]
result = validate_and_remove_orphan_tool_responses(messages, log_warning=False)
assert len(result) == 3
def test_responses_api_orphan_output_removed(self):
"""Orphan function_call_output (no matching function_call) should be removed."""
messages = [
{"role": "user", "content": "Do something."},
# No function_call — output is orphaned
FUNCTION_CALL_OUTPUT_ITEM,
]
result = validate_and_remove_orphan_tool_responses(messages, log_warning=False)
assert len(result) == 1
assert result[0]["role"] == "user"
# ═══════════════════════════════════════════════════════════════════════════
# _ensure_tool_pairs_intact
# ═══════════════════════════════════════════════════════════════════════════
class TestEnsureToolPairsIntactResponsesApi:
def test_chat_completions_pair_preserved(self):
"""Baseline: sliced Chat Completions tool responses get their assistant prepended."""
all_msgs = [
{"role": "system", "content": "sys"},
{
"role": "assistant",
"tool_calls": [{"id": "call_1", "type": "function"}],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
{"role": "user", "content": "thanks"},
]
# Slice starts at index 2 (tool response) — orphan
recent = [all_msgs[2], all_msgs[3]]
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index=2)
# Should prepend the assistant message
assert len(result) == 3
assert "tool_calls" in result[0]
def test_responses_api_pair_preserved(self):
"""Sliced function_call_output should get its function_call prepended."""
all_msgs = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "search for X"},
FUNCTION_CALL_ITEM,
FUNCTION_CALL_OUTPUT_ITEM,
{"role": "user", "content": "thanks"},
]
# Slice starts at index 3 (function_call_output) — orphan
recent = [all_msgs[3], all_msgs[4]]
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index=3)
# Should prepend the function_call item
assert len(result) == 3
assert result[0].get("type") == "function_call"
# ═══════════════════════════════════════════════════════════════════════════
# _summarize_messages_llm (minor)
# ═══════════════════════════════════════════════════════════════════════════
class TestSummarizeMessagesResponsesApi:
"""_summarize_messages_llm extracts content using msg.get("role") and
msg.get("content"). Responses API function_call and function_call_output
items carry neither a chat role nor "content", so they need explicit
handling; without it they would be silently dropped from the summary."""
@pytest.mark.asyncio
async def test_function_call_included_in_summary_text(self):
"""function_call items should contribute to the summary text."""
from backend.util.prompt import _summarize_messages_llm
messages = [
{"role": "user", "content": "Search for X"},
FUNCTION_CALL_ITEM,
FUNCTION_CALL_OUTPUT_ITEM,
{"role": "user", "content": "Thanks"},
]
# We only need to check the conversation text building, not the LLM call.
# The function builds conversation_text before calling the client.
# We mock the client to capture what it receives.
from unittest.mock import AsyncMock, MagicMock
mock_client = MagicMock()
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
mock_resp.choices[0].message.content = "Summary"
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
return_value=mock_resp
)
await _summarize_messages_llm(messages, mock_client, "gpt-4o")
# Check the prompt sent to the LLM contains tool info
call_args = (
mock_client.with_options.return_value.chat.completions.create.call_args
)
user_msg = call_args.kwargs["messages"][1]["content"]
# The tool name or arguments should appear in the summary text
assert "search_tool" in user_msg or "python asyncio" in user_msg
# ═══════════════════════════════════════════════════════════════════════════
# compress_context end-to-end
# ═══════════════════════════════════════════════════════════════════════════
class TestCompressContextResponsesApi:
@pytest.mark.asyncio
async def test_chat_completions_tool_pairs_preserved(self):
"""Baseline: Chat Completions tool pairs survive compaction."""
messages: list[dict] = [
{"role": "system", "content": "You are helpful."},
]
# Add enough messages to trigger compaction
for i in range(20):
messages.append({"role": "user", "content": f"Question {i} " * 200})
messages.append({"role": "assistant", "content": f"Answer {i} " * 200})
# Add a tool pair at the end
messages.append(
{
"role": "assistant",
"tool_calls": [
{"id": "call_final", "type": "function", "function": {"name": "f"}}
],
}
)
messages.append(
{"role": "tool", "tool_call_id": "call_final", "content": "result"}
)
messages.append({"role": "assistant", "content": "Done!"})
result = await compress_context(messages, target_tokens=2000, client=None)
# If tool response exists, its call must exist too
call_ids = set()
resp_ids = set()
for msg in result.messages:
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
call_ids.add(tc["id"])
if msg.get("role") == "tool":
resp_ids.add(msg.get("tool_call_id"))
assert resp_ids <= call_ids
@pytest.mark.asyncio
async def test_responses_api_tool_pairs_preserved(self):
"""Responses API function_call / function_call_output pairs must
survive compaction intact. Before _is_tool_message recognised these
items, they could be silently deleted during compaction."""
messages = [
{"role": "system", "content": "You are helpful."},
]
# Add enough messages to trigger compaction
for i in range(20):
messages.append({"role": "user", "content": f"Question {i} " * 200})
messages.append({"role": "assistant", "content": f"Answer {i} " * 200})
# Add a Responses API tool pair at the end
messages.append(
{
"type": "function_call",
"id": "fc_final",
"call_id": "call_final",
"name": "search_tool",
"arguments": '{"q": "test"}',
"status": "completed",
}
)
messages.append(
{
"type": "function_call_output",
"call_id": "call_final",
"output": '{"results": ["a", "b"]}',
}
)
messages.append({"role": "user", "content": "Thanks!"})
result = await compress_context(messages, target_tokens=2000, client=None)
# The function_call and function_call_output must both survive
fc_items = [m for m in result.messages if m.get("type") == "function_call"]
fco_items = [
m for m in result.messages if m.get("type") == "function_call_output"
]
# If either exists, the other must exist too (pair integrity)
if fc_items or fco_items:
fc_call_ids = {m["call_id"] for m in fc_items}
fco_call_ids = {m["call_id"] for m in fco_items}
assert (
fco_call_ids <= fc_call_ids
), "function_call_output exists without matching function_call"
# At minimum, neither should have been silently deleted if the
# conversation was short enough to keep them
assert len(fc_items) >= 1, "function_call was deleted during compaction"
assert len(fco_items) >= 1, "function_call_output was deleted during compaction"
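
The pair-integrity property exercised by `TestEnsureToolPairsIntactResponsesApi` can be illustrated in isolation. The sketch below covers only the Responses API case and is not the real `_ensure_tool_pairs_intact`; the name `prepend_missing_calls` and the backwards scan are assumptions made for illustration.

```python
def prepend_missing_calls(
    recent: list[dict], all_msgs: list[dict], start_index: int
) -> list[dict]:
    """Prepend function_call items whose outputs landed inside the slice."""
    present = {
        m["call_id"] for m in recent
        if m.get("type") == "function_call" and m.get("call_id")
    }
    needed = {
        m["call_id"] for m in recent
        if m.get("type") == "function_call_output" and m.get("call_id")
    } - present
    prefix: list[dict] = []
    # Walk backwards from the slice boundary so the nearest call is found first.
    for msg in reversed(all_msgs[:start_index]):
        if not needed:
            break
        if msg.get("type") == "function_call" and msg.get("call_id") in needed:
            prefix.insert(0, msg)  # keep the original ordering
            needed.discard(msg["call_id"])
    return prefix + recent


call = {"type": "function_call", "call_id": "call_abc",
        "name": "search_tool", "arguments": "{}"}
out = {"type": "function_call_output", "call_id": "call_abc", "output": "ok"}
all_msgs = [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "search for X"},
    call,
    out,
    {"role": "user", "content": "thanks"},
]
# Slicing at index 3 would orphan the output; the call is pulled back in.
repaired = prepend_missing_calls(all_msgs[3:], all_msgs, start_index=3)
```

A slice that already contains the call is returned unchanged, so the helper is safe to apply to every candidate window.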


@@ -704,8 +704,19 @@ def get_service_client(
return kwargs
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
"""Validate and coerce the RPC result to the expected return type.
Falls back to the raw result with a warning if validation fails.
"""
if expected_return:
return expected_return.validate_python(result)
try:
return expected_return.validate_python(result)
except Exception as e:
logger.warning(
"RPC return type validation failed, using raw result: %s",
type(e).__name__,
)
return result
return result
def __getattr__(self, name: str) -> Callable[..., Any]:
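
The effect of the new `try`/`except` in `_get_return` can be shown without the service client. In the sketch below, `IntValidator` is a hypothetical stand-in for a pydantic `TypeAdapter` (only the `validate_python` method is assumed), and `get_return` mirrors the fallback logic from the hunk as a free function.

```python
import logging
from typing import Any, Optional, Protocol

logger = logging.getLogger(__name__)


class Validator(Protocol):
    """Anything exposing validate_python, as TypeAdapter does."""

    def validate_python(self, value: Any) -> Any: ...


class IntValidator:
    """Hypothetical stand-in that coerces ints and numeric strings."""

    def validate_python(self, value: Any) -> int:
        if isinstance(value, bool) or not isinstance(value, (int, str)):
            raise TypeError(f"cannot coerce {type(value).__name__} to int")
        return int(value)  # raises ValueError for non-numeric strings


def get_return(expected: Optional[Validator], result: Any) -> Any:
    """Validate the result if a validator is given; fall back to the raw value."""
    if expected is None:
        return result
    try:
        return expected.validate_python(result)
    except Exception as exc:
        # Log only the exception type, as the hunk above does.
        logger.warning(
            "RPC return type validation failed, using raw result: %s",
            type(exc).__name__,
        )
        return result
```

One consequence worth keeping in mind when reviewing: after this change a caller can receive a value that does not match the declared return type, with the only signal being a warning in the logs.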


@@ -89,10 +89,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         le=500,
         description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
     )
-    tally_extraction_llm_model: str = Field(
-        default="openai/gpt-4o-mini",
-        description="OpenRouter model ID used for extracting business understanding from Tally form data",
-    )
     ollama_host: str = Field(
         default="localhost:11434",
         description="Default Ollama host; exempted from SSRF checks.",
@@ -121,10 +117,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         default=True,
         description="If authentication is enabled or not",
     )
-    enable_invite_gate: bool = Field(
-        default=False,
-        description="If the invite-only signup gate is enforced",
-    )
     enable_credit: bool = Field(
         default=False,
         description="If user credit system is enabled or not",


@@ -302,7 +302,14 @@ def _value_satisfies_type(value: Any, target: Any) -> bool:
     # Simple type (e.g. str, int)
     if isinstance(target, type):
-        return isinstance(value, target)
+        try:
+            return isinstance(value, target)
+        except TypeError:
+            # TypedDict and some typing constructs don't support isinstance checks.
+            # For TypedDict, check if value is a dict with the required keys.
+            if isinstance(value, dict) and hasattr(target, "__required_keys__"):
+                return all(k in value for k in target.__required_keys__)
+            return False
     return False
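
To see why the `try`/`except` is needed: a `TypedDict` class is itself an instance of `type`, yet `isinstance(value, SomeTypedDict)` raises `TypeError`. The sketch below reproduces the fallback in isolation; `Movie` and `satisfies` are hypothetical names used only for illustration.

```python
from typing import Any, TypedDict


class Movie(TypedDict):
    """Hypothetical TypedDict used only for this illustration."""

    title: str
    year: int


def satisfies(value: Any, target: Any) -> bool:
    """Mirror of the fallback: isinstance first, required keys for TypedDict."""
    try:
        return isinstance(value, target)
    except TypeError:
        if isinstance(value, dict) and hasattr(target, "__required_keys__"):
            return all(k in value for k in target.__required_keys__)
        return False


complete = satisfies({"title": "Alien", "year": 1979}, Movie)
missing_key = satisfies({"title": "Alien"}, Movie)
not_a_dict = satisfies("Alien", Movie)
wrong_value_types = satisfies({"title": 1979, "year": "Alien"}, Movie)
print(complete, missing_key, not_a_dict, wrong_value_types)
```

Note that the check covers key presence only. A dict whose values have the wrong types still passes, as the last call shows.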


@@ -0,0 +1,148 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"creatorId" TEXT,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"priceTier" INTEGER NOT NULL DEFAULT 1,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"isRecommended" BOOLEAN NOT NULL DEFAULT false,
"supportsTools" BOOLEAN NOT NULL DEFAULT false,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
"supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
"capabilities" JSONB NOT NULL DEFAULT '{}',
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- CreateIndex (partial unique for default costs - no specific credential)
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;
-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");
-- CreateIndex (partial unique to prevent multiple active migrations per source)
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddCheckConstraints (enforce data integrity)
ALTER TABLE "LlmModel"
ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);
ALTER TABLE "LlmModelCost"
ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);
ALTER TABLE "LlmModelMigration"
ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
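
The partial unique index `LlmModelMigration_active_source_key` allows many reverted migrations per source model but only one active one. The sketch below demonstrates that behaviour with SQLite as a stand-in for PostgreSQL. This is an assumption made so the example runs without a database server: the table is cut down to the two relevant columns, and `start_migration` is a hypothetical helper.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "LlmModelMigration" ('
    '"id" INTEGER PRIMARY KEY, '
    '"sourceModelSlug" TEXT NOT NULL, '
    '"isReverted" INTEGER NOT NULL DEFAULT 0)'
)
# Same shape as the migration's index: unique only among non-reverted rows.
conn.execute(
    'CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" '
    'ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = 0'
)


def start_migration(slug: str) -> bool:
    """Insert an active migration row; return False if one already exists."""
    try:
        conn.execute(
            'INSERT INTO "LlmModelMigration" ("sourceModelSlug") VALUES (?)',
            (slug,),
        )
        return True
    except sqlite3.IntegrityError:
        return False


first = start_migration("gpt-4-turbo")
second = start_migration("gpt-4-turbo")  # rejected: an active row exists
conn.execute(
    'UPDATE "LlmModelMigration" SET "isReverted" = 1 WHERE "sourceModelSlug" = ?',
    ("gpt-4-turbo",),
)
third = start_migration("gpt-4-turbo")  # allowed again after the revert
print(first, second, third)
```

Enforcing the rule in the database rather than in application code means concurrent admin requests cannot both create an active migration for the same source model.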


@@ -0,0 +1,287 @@
-- Seed LLM Registry from existing hard-coded data
-- This migration populates the LlmProvider, LlmModelCreator, LlmModel, and LlmModelCost tables
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
-- Insert Providers
INSERT INTO "LlmProvider" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "metadata")
VALUES
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;
-- Insert Model Creators
INSERT INTO "LlmModelCreator" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "websiteUrl", "logoUrl", "metadata")
VALUES
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT, O1, O3, and DALL-E models', 'https://openai.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude AI models', 'https://anthropic.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama foundation models', 'https://llama.meta.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini and PaLM models', 'https://deepmind.google', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'mistralai', 'Mistral AI', 'Creator of Mistral and Codestral models', 'https://mistral.ai', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command language models', 'https://cohere.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek reasoning models', 'https://deepseek.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'alibaba', 'Alibaba', 'Creator of Qwen language models', 'https://qwenlm.github.io', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 AI models', 'https://v0.dev', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of Phi models', 'https://microsoft.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar search models', 'https://perplexity.ai', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'nousresearch', 'Nous Research', 'Creator of Hermes language models', 'https://nousresearch.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova language models', 'https://aws.amazon.com', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', NULL, '{}'::jsonb),
(gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'moonshotai', 'Moonshot AI', 'Creator of Kimi language models', 'https://moonshot.ai', NULL, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;
-- Insert Models (using CTEs to reference provider and creator IDs)
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
),
creator_ids AS (
SELECT "id", "name" FROM "LlmModelCreator"
)
INSERT INTO "LlmModel" ("id", "createdAt", "updatedAt", "slug", "displayName", "description", "providerId", "creatorId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
model_slug,
model_display_name,
NULL,
p."id",
c."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- OpenAI models (creator: openai)
('o3-2025-04-16', 'O3', 'openai', 'openai', 200000, 100000),
('o3-mini', 'O3 Mini', 'openai', 'openai', 200000, 100000),
('o1', 'O1', 'openai', 'openai', 200000, 100000),
('o1-mini', 'O1 Mini', 'openai', 'openai', 128000, 65536),
('gpt-5.2-2025-12-11', 'GPT-5.2', 'openai', 'openai', 400000, 128000),
('gpt-5-2025-08-07', 'GPT 5', 'openai', 'openai', 400000, 128000),
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 'openai', 400000, 128000),
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 'openai', 400000, 128000),
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 'openai', 400000, 128000),
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 'openai', 400000, 16384),
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 'openai', 1000000, 32768),
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 'openai', 1047576, 32768),
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 'openai', 128000, 16384),
('gpt-4o', 'GPT 4o', 'openai', 'openai', 128000, 16384),
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 'openai', 128000, 4096),
-- Anthropic models (creator: anthropic)
('claude-opus-4-6', 'Claude Opus 4.6', 'anthropic', 'anthropic', 200000, 128000),
('claude-sonnet-4-6', 'Claude Sonnet 4.6', 'anthropic', 'anthropic', 200000, 64000),
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 'anthropic', 200000, 32000),
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 'anthropic', 200000, 32000),
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 'anthropic', 200000, 64000),
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 'anthropic', 200000, 64000),
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 'anthropic', 200000, 4096),
-- AI/ML API models (creators: alibaba, nvidia, meta)
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 'alibaba', 32000, 8000),
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 'nvidia', 128000, 40000),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 'meta', 128000, NULL),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 'meta', 131000, 2000),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 'meta', 128000, NULL),
-- Groq models (creator: meta for Llama)
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 'meta', 128000, 32768),
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 'meta', 128000, 8192),
-- Ollama models (creators: meta for Llama, mistralai for Mistral)
('llama3.3', 'Llama 3.3', 'ollama', 'meta', 8192, NULL),
('llama3.2', 'Llama 3.2', 'ollama', 'meta', 8192, NULL),
('llama3', 'Llama 3', 'ollama', 'meta', 8192, NULL),
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 'meta', 8192, NULL),
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 'mistralai', 32768, NULL),
-- OpenRouter models (creators: google, mistralai, cohere, deepseek, perplexity, nousresearch, openai, amazon, microsoft, gryphe, meta, xai, moonshotai, alibaba)
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 'google', 1050000, 8192),
('google/gemini-2.5-pro', 'Gemini 2.5 Pro', 'open_router', 'google', 1048576, 65536),
('google/gemini-3.1-pro-preview', 'Gemini 3.1 Pro Preview', 'open_router', 'google', 1048576, 65536),
('google/gemini-3-flash-preview', 'Gemini 3 Flash Preview', 'open_router', 'google', 1048576, 65536),
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 'google', 1048576, 65535),
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 'google', 1048576, 8192),
('google/gemini-3.1-flash-lite-preview', 'Gemini 3.1 Flash Lite Preview', 'open_router', 'google', 1048576, 65536),
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 'google', 1048576, 65535),
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 'google', 1048576, 8192),
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 'mistralai', 128000, 4096),
('mistralai/mistral-large-2512', 'Mistral Large 3 2512', 'open_router', 'mistralai', 262144, NULL),
('mistralai/mistral-medium-3.1', 'Mistral Medium 3.1', 'open_router', 'mistralai', 131072, NULL),
('mistralai/mistral-small-3.2-24b-instruct', 'Mistral Small 3.2 24B', 'open_router', 'mistralai', 131072, 131072),
('mistralai/codestral-2508', 'Codestral 2508', 'open_router', 'mistralai', 256000, NULL),
('cohere/command-r-08-2024', 'Command R', 'open_router', 'cohere', 128000, 4096),
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 'cohere', 128000, 4096),
('cohere/command-a-03-2025', 'Command A 03.2025', 'open_router', 'cohere', 256000, 8192),
('cohere/command-a-reasoning-08-2025', 'Command A Reasoning 08.2025', 'open_router', 'cohere', 256000, 32768),
('cohere/command-a-translate-08-2025', 'Command A Translate 08.2025', 'open_router', 'cohere', 128000, 8192),
('cohere/command-a-vision-07-2025', 'Command A Vision 07.2025', 'open_router', 'cohere', 128000, 8192),
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 'deepseek', 64000, 2048),
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 'deepseek', 163840, 163840),
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 'perplexity', 127000, 8000),
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 'perplexity', 200000, 8000),
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 'perplexity', 128000, 16000),
('perplexity/sonar-reasoning-pro', 'Sonar Reasoning Pro', 'open_router', 'perplexity', 128000, 8000),
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 'nousresearch', 131000, 4096),
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 'nousresearch', 12288, 12288),
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 'openai', 131072, 131072),
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 'openai', 131072, 32768),
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 'amazon', 300000, 5120),
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 'amazon', 128000, 5120),
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 'amazon', 300000, 5120),
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 'microsoft', 65536, 4096),
('microsoft/phi-4', 'Phi-4', 'open_router', 'microsoft', 16384, 16384),
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 'gryphe', 4096, 4096),
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 'meta', 131072, 131072),
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 'meta', 1048576, 1000000),
('x-ai/grok-3', 'Grok 3', 'open_router', 'xai', 131072, 131072),
('x-ai/grok-4', 'Grok 4', 'open_router', 'xai', 256000, 256000),
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 'xai', 2000000, 30000),
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 'xai', 2000000, 30000),
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 'xai', 256000, 10000),
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 'moonshotai', 131000, 131000),
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 'alibaba', 262144, 262144),
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 'alibaba', 262144, 262144),
-- Llama API models (creator: meta)
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 'meta', 128000, 4028),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 'meta', 128000, 4028),
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 'meta', 128000, 4028),
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 'meta', 128000, 4028),
-- v0 models (creator: vercel)
('v0-1.5-md', 'v0 1.5 MD', 'v0', 'vercel', 128000, 64000),
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 'vercel', 512000, 64000),
('v0-1.0-md', 'v0 1.0 MD', 'v0', 'vercel', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, creator_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
JOIN creator_ids c ON c."name" = models.creator_name
ON CONFLICT ("slug") DO NOTHING;
-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "createdAt", "updatedAt", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- OpenAI costs
('o3-2025-04-16', 4),
('o3-mini', 2),
('o1', 16),
('o1-mini', 4),
('gpt-5.2-2025-12-11', 5),
('gpt-5-2025-08-07', 2),
('gpt-5.1-2025-11-13', 5),
('gpt-5-mini-2025-08-07', 1),
('gpt-5-nano-2025-08-07', 1),
('gpt-5-chat-latest', 5),
('gpt-4.1-2025-04-14', 2),
('gpt-4.1-mini-2025-04-14', 1),
('gpt-4o-mini', 1),
('gpt-4o', 3),
('gpt-4-turbo', 10),
-- Anthropic costs
('claude-opus-4-6', 21),
('claude-sonnet-4-6', 5),
('claude-opus-4-1-20250805', 21),
('claude-opus-4-20250514', 21),
('claude-sonnet-4-20250514', 5),
('claude-haiku-4-5-20251001', 4),
('claude-opus-4-5-20251101', 14),
('claude-sonnet-4-5-20250929', 9),
('claude-3-haiku-20240307', 1),
-- AI/ML API costs
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
-- Groq costs
('llama-3.3-70b-versatile', 1),
('llama-3.1-8b-instant', 1),
-- Ollama costs
('llama3.3', 1),
('llama3.2', 1),
('llama3', 1),
('llama3.1:405b', 1),
('dolphin-mistral:latest', 1),
-- OpenRouter costs
('google/gemini-2.5-pro-preview-03-25', 4),
('google/gemini-2.5-pro', 4),
('google/gemini-3.1-pro-preview', 5),
('google/gemini-3-flash-preview', 3),
('google/gemini-3.1-flash-lite-preview', 1),
('mistralai/mistral-nemo', 1),
('mistralai/mistral-large-2512', 3),
('mistralai/mistral-medium-3.1', 2),
('mistralai/mistral-small-3.2-24b-instruct', 1),
('mistralai/codestral-2508', 2),
('cohere/command-r-08-2024', 1),
('cohere/command-r-plus-08-2024', 3),
('cohere/command-a-03-2025', 2),
('cohere/command-a-reasoning-08-2025', 3),
('cohere/command-a-translate-08-2025', 1),
('cohere/command-a-vision-07-2025', 2),
('deepseek/deepseek-chat', 2),
('perplexity/sonar', 1),
('perplexity/sonar-pro', 5),
('perplexity/sonar-deep-research', 10),
('perplexity/sonar-reasoning-pro', 5),
('nousresearch/hermes-3-llama-3.1-405b', 1),
('nousresearch/hermes-3-llama-3.1-70b', 1),
('amazon/nova-lite-v1', 1),
('amazon/nova-micro-v1', 1),
('amazon/nova-pro-v1', 1),
('microsoft/wizardlm-2-8x22b', 1),
('microsoft/phi-4', 1),
('gryphe/mythomax-l2-13b', 1),
('meta-llama/llama-4-scout', 1),
('meta-llama/llama-4-maverick', 1),
('x-ai/grok-3', 5),
('x-ai/grok-4', 9),
('x-ai/grok-4-fast', 1),
('x-ai/grok-4.1-fast', 1),
('x-ai/grok-code-fast-1', 1),
('moonshotai/kimi-k2', 1),
('qwen/qwen3-235b-a22b-thinking-2507', 1),
('qwen/qwen3-coder', 9),
('google/gemini-2.5-flash', 1),
('google/gemini-2.0-flash-001', 1),
('google/gemini-2.5-flash-lite-preview-06-17', 1),
('google/gemini-2.0-flash-lite-001', 1),
('deepseek/deepseek-r1-0528', 1),
('openai/gpt-oss-120b', 1),
('openai/gpt-oss-20b', 1),
-- Llama API costs
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
('Llama-3.3-8B-Instruct', 1),
('Llama-3.3-70B-Instruct', 1),
-- v0 costs
('v0-1.5-md', 1),
('v0-1.5-lg', 2),
('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL DO NOTHING;
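
Every cost row seeded above has `credentialId` set to `NULL`, so each is a provider-level default. The lookup code is not part of this diff, so the following is only a sketch of how a default and a credential-specific override might be resolved. `CostRow`, `resolve_cost`, and the "override first, then default" order are all assumptions.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class CostRow:
    """Hypothetical in-memory view of an LlmModelCost row."""

    llm_model_id: str
    credential_provider: str
    unit: str
    credit_cost: int
    credential_id: Optional[str] = None


def resolve_cost(
    rows: Iterable[CostRow],
    model_id: str,
    provider: str,
    unit: str = "RUN",
    credential_id: Optional[str] = None,
) -> Optional[int]:
    """Prefer a credential-specific cost; otherwise use the default row."""
    default: Optional[int] = None
    for row in rows:
        key = (row.llm_model_id, row.credential_provider, row.unit)
        if key != (model_id, provider, unit):
            continue
        if credential_id is not None and row.credential_id == credential_id:
            return row.credit_cost
        if row.credential_id is None:
            default = row.credit_cost
    return default


rows = [
    CostRow("model-1", "openai", "RUN", 3),              # seeded default
    CostRow("model-1", "openai", "RUN", 0, "cred-123"),  # hypothetical override
]
default_cost = resolve_cost(rows, "model-1", "openai")
override_cost = resolve_cost(rows, "model-1", "openai", credential_id="cred-123")
unknown_cred = resolve_cost(rows, "model-1", "openai", credential_id="cred-999")
no_rows = resolve_cost(rows, "model-2", "openai")
print(default_cost, override_cost, unknown_cred, no_rows)
```

The two partial unique indexes guarantee at most one matching row of each kind, which is why the loop can return on the first credential match.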


@@ -0,0 +1,107 @@
-- Revert the invite system: drop InvitedUser table + enums, restore User+Profile trigger.
-- Uses current_schema() so the migration works regardless of the configured schema name.
-- 1) Drop the InvitedUser table (also drops its indexes and FK constraints)
DROP TABLE IF EXISTS "InvitedUser";
-- 2) Drop the enums introduced by the invite system
DROP TYPE IF EXISTS "InvitedUserStatus";
DROP TYPE IF EXISTS "TallyComputationStatus";
-- 3) Restore the User+Profile auto-creation trigger on auth.users.
-- Original definition from migration 20250205100104_add_profile_trigger.
-- generate_username() was never dropped and is still present.
-- Uses EXECUTE + format() to inject current_schema() into SET search_path,
-- so the trigger function resolves tables correctly when fired from auth.
DO $$
DECLARE
cs text := current_schema();
BEGIN
EXECUTE format($fn$
CREATE OR REPLACE FUNCTION add_user_and_profile_to_platform()
RETURNS TRIGGER
LANGUAGE plpgsql
SECURITY DEFINER
SET search_path = %I
AS $trigger$
BEGIN
IF NEW.id IS NULL THEN
RAISE EXCEPTION 'Cannot create user/profile: id is null';
END IF;
INSERT INTO "User" (id, email, "updatedAt")
VALUES (NEW.id, NEW.email, now());
INSERT INTO "Profile"
("id", "userId", name, username, description, links, "avatarUrl", "updatedAt")
VALUES (
NEW.id,
NEW.id,
COALESCE(split_part(NEW.email, '@', 1), 'user'),
generate_username(),
'I''m new here',
'{}',
'',
now()
);
RETURN NEW;
EXCEPTION
WHEN OTHERS THEN
RAISE NOTICE 'Error in add_user_and_profile_to_platform: %%', SQLERRM;
RAISE;
END;
$trigger$
$fn$, cs);
END $$;
-- 4) Backfill: create User + Profile rows for any auth.users rows that were
-- created while the trigger was absent (during the invite-system window).
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = 'auth' AND table_name = 'users'
) THEN
INSERT INTO "User" (id, email, "updatedAt")
SELECT au.id::text, au.email, now()
FROM auth.users au
LEFT JOIN "User" pu ON pu.id = au.id::text
WHERE pu.id IS NULL
ON CONFLICT (id) DO NOTHING;
INSERT INTO "Profile"
(id, "userId", name, username, description, links, "avatarUrl", "updatedAt")
SELECT
gen_random_uuid()::text,
au.id::text,
COALESCE(NULLIF(split_part(au.email, '@', 1), ''), 'user'),
generate_username(),
'I''m new here',
'{}',
'',
now()
FROM auth.users au
LEFT JOIN "Profile" pp ON pp."userId" = au.id::text
WHERE pp."userId" IS NULL
ON CONFLICT ("userId") DO NOTHING;
END IF;
END $$;
-- 5) Restore the trigger for future signups.
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'auth'
AND table_name = 'users'
) THEN
DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;
CREATE TRIGGER user_added_to_platform
AFTER INSERT ON auth.users
FOR EACH ROW EXECUTE FUNCTION add_user_and_profile_to_platform();
END IF;
END $$;
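
The trigger and the backfill derive the default profile name slightly differently: the backfill wraps `split_part` in `NULLIF(..., '')`, while the restored trigger does not. A Python sketch of both expressions makes the difference visible. The function names are hypothetical, and `None` stands in for SQL `NULL`.

```python
from typing import Optional


def backfill_profile_name(email: Optional[str]) -> str:
    """COALESCE(NULLIF(split_part(email, '@', 1), ''), 'user')"""
    if email is None:
        return "user"
    local_part = email.split("@", 1)[0]
    return local_part or "user"


def trigger_profile_name(email: Optional[str]) -> str:
    """COALESCE(split_part(email, '@', 1), 'user'), with no NULLIF."""
    if email is None:
        return "user"
    return email.split("@", 1)[0]


normal = backfill_profile_name("alice@example.com")
null_email = backfill_profile_name(None)
empty_local_backfill = backfill_profile_name("@example.com")
empty_local_trigger = trigger_profile_name("@example.com")
print(repr(normal), repr(null_email))
print(repr(empty_local_backfill), repr(empty_local_trigger))
```

For an address with an empty local part the trigger stores an empty name while the backfill stores `user`. The trigger is restored from the original migration, so the difference is probably only a matter of consistency.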


@@ -65,7 +65,6 @@ model User {
   NotificationBatches UserNotificationBatch[]
   PendingHumanReviews PendingHumanReview[]
   Workspace UserWorkspace?
-  ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
   // OAuth Provider relations
   OAuthApplications OAuthApplication[]
@@ -74,38 +73,6 @@ model User {
   OAuthRefreshTokens OAuthRefreshToken[]
 }
-enum InvitedUserStatus {
-  INVITED
-  CLAIMED
-  REVOKED
-}
-enum TallyComputationStatus {
-  PENDING
-  RUNNING
-  READY
-  FAILED
-}
-model InvitedUser {
-  id String @id @default(uuid())
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
-  email String @unique
-  status InvitedUserStatus @default(INVITED)
-  authUserId String? @unique
-  AuthUser User? @relation("InvitedUserAuthUser", fields: [authUserId], references: [id], onDelete: SetNull)
-  name String?
-  tallyUnderstanding Json?
-  tallyStatus TallyComputationStatus @default(PENDING)
-  tallyComputedAt DateTime?
-  tallyError String?
-  @@index([createdAt])
-}
 enum OnboardingStep {
   // Introductory onboarding (Library)
   WELCOME
@@ -1025,7 +992,7 @@ model StoreListing {
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
// The agent link here is only so we can do lookup on agentId
agentGraphId String @unique
agentGraphId String @unique
owningUserId String
OwningUser User @relation(fields: [owningUserId], references: [id])
@@ -1334,3 +1301,164 @@ model OAuthRefreshToken {
@@index([userId, applicationId])
@@index([expiresAt]) // For cleanup
}
// ============================================================================
// LLM Registry Models
// ============================================================================
enum LlmCostUnit {
RUN
TOKENS
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
// Model-specific capabilities
// These vary per model even within the same provider (e.g., Hugging Face)
// Default to false for safety - partially-seeded rows should not be assumed capable
supportsTools Boolean @default(false)
supportsJsonOutput Boolean @default(false)
supportsReasoning Boolean @default(false)
supportsParallelToolCalls Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
TargetMigrations LlmModelMigration[] @relation("TargetMigrations")
@@index([providerId, isEnabled])
@@index([creatorId])
// Note: slug already has @unique which creates an implicit index
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int // DB constraint: >= 0
// Provider identifier (e.g., "openai", "anthropic", "open_router")
// Used to determine which credential system provides the API key.
// Allows different pricing for:
// - Default provider costs (WHERE credentialId IS NULL)
// - User's own API key costs (WHERE credentialId IS NOT NULL)
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
// Note: Unique constraints are implemented as partial indexes in migration SQL:
// - One for default costs (WHERE credentialId IS NULL)
// - One for credential-specific costs (WHERE credentialId IS NOT NULL)
// This allows both provider-level defaults and credential-specific overrides
}
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// FK constraints ensure slugs reference valid models
SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated (DB constraint: >= 0)
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
// LIMITATION: This is a simple Int and doesn't distinguish RUN vs TOKENS pricing.
// For token-priced models, this may be ambiguous. Consider migrating to a relation
// with LlmModelCost or a dedicated override model in a follow-up PR.
customCreditCost Int? // DB constraint: >= 0 when not null
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
// Note: Partial unique index in migration SQL prevents multiple active migrations per source:
// UNIQUE (sourceModelSlug) WHERE isReverted = false
@@index([targetModelSlug])
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}
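Review note: the `customCreditCost` comment says billing should use the override when it is set and the target model's cost when it is null. The sketch below shows one way that rule could be expressed in Python. `effective_credit_cost` is a hypothetical helper, not code from this PR, and ignoring the override once a migration is reverted is an assumption rather than documented behaviour.

```python
from typing import Optional


def effective_credit_cost(
    target_model_cost: int,
    custom_credit_cost: Optional[int],
    is_reverted: bool = False,
) -> int:
    """Pick the credit cost to bill for a node affected by a migration.

    Assumption: a reverted migration no longer applies its override.
    """
    if is_reverted or custom_credit_cost is None:
        return target_model_cost
    if custom_credit_cost < 0:
        # Mirrors the ">= 0 when not null" DB constraint noted in the schema.
        raise ValueError("customCreditCost must be >= 0")
    return custom_credit_cost
```

The explicit `is None` check matters: an override of `0` (free during the transition) is a valid value and must not be mistaken for "no override".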


@@ -0,0 +1,123 @@
#!/usr/bin/env bash
# refresh_claude_token.sh — Extract Claude OAuth tokens and update backend/.env
#
# Works on macOS (keychain), Linux and WSL (~/.claude/.credentials.json),
# and Windows Git Bash/MSYS2/Cygwin (%APPDATA%/claude/.credentials.json, then ~/.claude/.credentials.json).
#
# Usage:
# ./scripts/refresh_claude_token.sh # auto-detect OS
# ./scripts/refresh_claude_token.sh --env-file /path/to/.env # custom .env path
#
# Prerequisites: `jq` must be installed, and you must have run `claude login` at least once on the host.
set -euo pipefail
# --- Parse arguments ---
ENV_FILE=""
while [[ $# -gt 0 ]]; do
case "$1" in
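--env-file)
# Under `set -u`, a missing value would otherwise fail with an unhelpful "unbound variable" error.
[[ $# -ge 2 ]] || { echo "ERROR: --env-file requires a path" >&2; exit 1; }
ENV_FILE="$2"; shift 2 ;;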
*) echo "Unknown option: $1"; exit 1 ;;
esac
done
# Default .env path: relative to this script's location
if [[ -z "$ENV_FILE" ]]; then
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ENV_FILE="$SCRIPT_DIR/../.env"
fi
# --- Extract tokens by platform ---
ACCESS_TOKEN=""
REFRESH_TOKEN=""
extract_from_credentials_file() {
local creds_file="$1"
if [[ -f "$creds_file" ]]; then
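# `|| true` keeps `set -e` from aborting silently if jq fails; the validation step reports the problem instead.
ACCESS_TOKEN=$(jq -r '.claudeAiOauth.accessToken // ""' "$creds_file" 2>/dev/null || true)
REFRESH_TOKEN=$(jq -r '.claudeAiOauth.refreshToken // ""' "$creds_file" 2>/dev/null || true)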
fi
}
case "$(uname -s)" in
Darwin)
# macOS: extract from system keychain
CREDS_JSON=$(security find-generic-password -s "Claude Code-credentials" -w 2>/dev/null || true)
if [[ -n "$CREDS_JSON" ]]; then
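# `|| true` keeps `set -e` from aborting silently if the keychain entry is not valid JSON.
ACCESS_TOKEN=$(echo "$CREDS_JSON" | jq -r '.claudeAiOauth.accessToken // ""' 2>/dev/null || true)
REFRESH_TOKEN=$(echo "$CREDS_JSON" | jq -r '.claudeAiOauth.refreshToken // ""' 2>/dev/null || true)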
else
# Fallback to credentials file (e.g. if keychain access denied)
extract_from_credentials_file "$HOME/.claude/.credentials.json"
fi
;;
Linux)
# Linux (including WSL): read from credentials file
extract_from_credentials_file "$HOME/.claude/.credentials.json"
;;
MINGW*|MSYS*|CYGWIN*)
# Windows Git Bash / MSYS2 / Cygwin
APPDATA_PATH="${APPDATA:-$USERPROFILE/AppData/Roaming}"
extract_from_credentials_file "$APPDATA_PATH/claude/.credentials.json"
# Fallback to home dir
if [[ -z "$ACCESS_TOKEN" ]]; then
extract_from_credentials_file "$HOME/.claude/.credentials.json"
fi
;;
*)
echo "Unsupported platform: $(uname -s)"
exit 1
;;
esac
# --- Validate ---
if [[ -z "$ACCESS_TOKEN" ]]; then
echo "ERROR: Could not extract Claude OAuth token."
echo ""
echo "Make sure you have run 'claude login' at least once."
echo ""
echo "Locations checked:"
echo " macOS: Keychain ('Claude Code-credentials')"
echo " Linux: ~/.claude/.credentials.json"
echo " Windows: %APPDATA%/claude/.credentials.json"
exit 1
fi
echo "Found Claude OAuth token: ${ACCESS_TOKEN:0:20}..."
[[ -n "$REFRESH_TOKEN" ]] && echo "Found refresh token: ${REFRESH_TOKEN:0:20}..."
# --- Update .env file ---
update_env_var() {
local key="$1" value="$2" file="$3"
# Escape characters that are special in a sed replacement (\, &, and the | delimiter).
local escaped
escaped=$(printf '%s' "$value" | sed -e 's/[\\&|]/\\&/g')
if grep -q "^${key}=" "$file" 2>/dev/null; then
# Replace existing value (BSD sed on macOS needs an explicit empty backup suffix)
if [[ "$(uname -s)" == "Darwin" ]]; then
sed -i '' "s|^${key}=.*|${key}=${escaped}|" "$file"
else
sed -i "s|^${key}=.*|${key}=${escaped}|" "$file"
fi
elif grep -q "^# *${key}=" "$file" 2>/dev/null; then
# Uncomment and set
if [[ "$(uname -s)" == "Darwin" ]]; then
sed -i '' "s|^# *${key}=.*|${key}=${escaped}|" "$file"
else
sed -i "s|^# *${key}=.*|${key}=${escaped}|" "$file"
fi
else
# Append
echo "${key}=${value}" >> "$file"
fi
}
if [[ ! -f "$ENV_FILE" ]]; then
echo "WARNING: $ENV_FILE does not exist, creating it."
touch "$ENV_FILE"
fi
update_env_var "CLAUDE_CODE_OAUTH_TOKEN" "$ACCESS_TOKEN" "$ENV_FILE"
[[ -n "$REFRESH_TOKEN" ]] && update_env_var "CLAUDE_CODE_REFRESH_TOKEN" "$REFRESH_TOKEN" "$ENV_FILE"
update_env_var "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION" "true" "$ENV_FILE"
echo ""
echo "Updated $ENV_FILE with Claude subscription tokens."
echo "Run 'docker compose up -d copilot_executor' to apply."
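Review note: `update_env_var` applies a three-way rule: replace an active `KEY=` line, otherwise uncomment and set a `# KEY=` line, otherwise append. The Python sketch below restates that rule as a pure function on the file's text so the branches are easy to see and test. `update_env_text` is a hypothetical name used only for illustration and is not part of the script.

```python
import re


def update_env_text(text: str, key: str, value: str) -> str:
    """Return .env content with `key` set to `value`.

    Order of preference: replace active lines, else uncomment commented
    lines, else append a new line at the end.
    """
    lines = text.splitlines()
    new_line = f"{key}={value}"
    active = re.compile(rf"^{re.escape(key)}=")
    commented = re.compile(rf"^# *{re.escape(key)}=")

    for pattern in (active, commented):
        if any(pattern.match(line) for line in lines):
            # Rewrite every matching line, as a global sed substitution would.
            lines = [new_line if pattern.match(line) else line for line in lines]
            break
    else:
        # Neither an active nor a commented entry exists.
        lines.append(new_line)

    return "\n".join(lines) + "\n"
```

Because the new line is built with plain string formatting, the value is written literally and needs no escaping in this version.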


@@ -34,7 +34,7 @@ from backend.data.auth.api_key import create_api_key
from backend.data.credit import get_user_credit_model
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.invited_user import get_or_activate_user
from backend.data.user import get_or_create_user
from backend.util.clients import get_supabase
faker = Faker()
@@ -151,7 +151,7 @@ class TestDataCreator:
}
# Use the API function to create user in local database
user = await get_or_activate_user(user_data)
user = await get_or_create_user(user_data)
users.append(user.model_dump())
except Exception as e:


@@ -84,7 +84,7 @@ See @CONTRIBUTING.md for complete patterns. Quick reference:
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
5. **Testing**: Add Storybook stories for new components, Playwright for E2E. When fixing a bug, write a failing Playwright test first (use `.fixme` annotation), implement the fix, then remove the annotation.
6. **Code conventions**:
- Use function declarations (not arrow functions) for components/handlers
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function


@@ -73,7 +73,7 @@
"@vercel/analytics": "1.5.0",
"@vercel/speed-insights": "1.2.0",
"@xyflow/react": "12.9.2",
"ai": "6.0.59",
"ai": "6.0.134",
"boring-avatars": "1.11.2",
"canvas-confetti": "1.9.4",
"class-variance-authority": "0.7.1",


@@ -142,8 +142,8 @@ importers:
specifier: 12.9.2
version: 12.9.2(@types/react@18.3.17)(immer@11.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
ai:
specifier: 6.0.59
version: 6.0.59(zod@3.25.76)
specifier: 6.0.134
version: 6.0.134(zod@3.25.76)
boring-avatars:
specifier: 1.11.2
version: 1.11.2
@@ -448,16 +448,32 @@ packages:
peerDependencies:
zod: ^3.25.76 || ^4.1.8
'@ai-sdk/gateway@3.0.77':
resolution: {integrity: sha512-UdwIG2H2YMuntJQ5L+EmED5XiwnlvDT3HOmKfVFxR4Nq/RSLFA/HcchhwfNXHZ5UJjyuL2VO0huLbWSZ9ijemQ==}
engines: {node: '>=18'}
peerDependencies:
zod: ^3.25.76 || ^4.1.8
'@ai-sdk/provider-utils@4.0.10':
resolution: {integrity: sha512-VeDAiCH+ZK8Xs4hb9Cw7pHlujWNL52RKe8TExOkrw6Ir1AmfajBZTb9XUdKOZO08RwQElIKA8+Ltm+Gqfo8djQ==}
engines: {node: '>=18'}
peerDependencies:
zod: ^3.25.76 || ^4.1.8
'@ai-sdk/provider-utils@4.0.21':
resolution: {integrity: sha512-MtFUYI1/8mgDvRmaBDjbLJPFFrMG777AvSgyIFQtZHIMzm88R/12vYBBpnk7pfiWLFE1DSZzY4WDYzGbKAcmiw==}
engines: {node: '>=18'}
peerDependencies:
zod: ^3.25.76 || ^4.1.8
'@ai-sdk/provider@3.0.5':
resolution: {integrity: sha512-2Xmoq6DBJqmSl80U6V9z5jJSJP7ehaJJQMy2iFUqTay06wdCqTnPVBBQbtEL8RCChenL+q5DC5H5WzU3vV3v8w==}
engines: {node: '>=18'}
'@ai-sdk/provider@3.0.8':
resolution: {integrity: sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==}
engines: {node: '>=18'}
'@ai-sdk/react@3.0.61':
resolution: {integrity: sha512-vCjZBnY2+TawFBXamSKt6elAt9n1MXMfcjSd9DSgT9peCJN27qNGVSXgaGNh/B3cUgeOktFfhB2GVmIqOjvmLQ==}
engines: {node: '>=18'}
@@ -4053,6 +4069,12 @@ packages:
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
engines: {node: '>= 14'}
ai@6.0.134:
resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
engines: {node: '>=18'}
peerDependencies:
zod: ^3.25.76 || ^4.1.8
ai@6.0.59:
resolution: {integrity: sha512-9SfCvcr4kVk4t8ZzIuyHpuL1hFYKsYMQfBSbBq3dipXPa+MphARvI8wHEjNaRqYl3JOsJbWxEBIMqHL0L92mUA==}
engines: {node: '>=18'}
@@ -8718,6 +8740,13 @@ snapshots:
'@vercel/oidc': 3.1.0
zod: 3.25.76
'@ai-sdk/gateway@3.0.77(zod@3.25.76)':
dependencies:
'@ai-sdk/provider': 3.0.8
'@ai-sdk/provider-utils': 4.0.21(zod@3.25.76)
'@vercel/oidc': 3.1.0
zod: 3.25.76
'@ai-sdk/provider-utils@4.0.10(zod@3.25.76)':
dependencies:
'@ai-sdk/provider': 3.0.5
@@ -8725,10 +8754,21 @@ snapshots:
eventsource-parser: 3.0.6
zod: 3.25.76
'@ai-sdk/provider-utils@4.0.21(zod@3.25.76)':
dependencies:
'@ai-sdk/provider': 3.0.8
'@standard-schema/spec': 1.1.0
eventsource-parser: 3.0.6
zod: 3.25.76
'@ai-sdk/provider@3.0.5':
dependencies:
json-schema: 0.4.0
'@ai-sdk/provider@3.0.8':
dependencies:
json-schema: 0.4.0
'@ai-sdk/react@3.0.61(react@18.3.1)(zod@3.25.76)':
dependencies:
'@ai-sdk/provider-utils': 4.0.10(zod@3.25.76)
@@ -12798,6 +12838,14 @@ snapshots:
agent-base@7.1.4:
optional: true
ai@6.0.134(zod@3.25.76):
dependencies:
'@ai-sdk/gateway': 3.0.77(zod@3.25.76)
'@ai-sdk/provider': 3.0.8
'@ai-sdk/provider-utils': 4.0.21(zod@3.25.76)
'@opentelemetry/api': 1.9.0
zod: 3.25.76
ai@6.0.59(zod@3.25.76):
dependencies:
'@ai-sdk/gateway': 3.0.27(zod@3.25.76)
@@ -14066,8 +14114,8 @@ snapshots:
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
eslint-plugin-react: 7.37.5(eslint@8.57.1)
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
@@ -14086,7 +14134,7 @@ snapshots:
transitivePeerDependencies:
- supports-color
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1):
dependencies:
'@nolyfill/is-core-module': 1.0.39
debug: 4.4.3
@@ -14097,22 +14145,22 @@ snapshots:
tinyglobby: 0.2.15
unrs-resolver: 1.11.1
optionalDependencies:
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
transitivePeerDependencies:
- supports-color
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
dependencies:
debug: 3.2.7
optionalDependencies:
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
transitivePeerDependencies:
- supports-color
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
dependencies:
'@rtsao/scc': 1.1.0
array-includes: 3.1.9
@@ -14123,7 +14171,7 @@ snapshots:
doctrine: 2.1.0
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
hasown: 2.0.2
is-core-module: 2.16.1
is-glob: 4.0.3


@@ -1,14 +1,7 @@
"use client";
import { Sidebar } from "@/components/__legacy__/Sidebar";
import {
UsersIcon,
CurrencyDollarSimpleIcon,
UserPlusIcon,
MagnifyingGlassIcon,
FileTextIcon,
SlidersHorizontalIcon,
} from "@phosphor-icons/react";
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
import { IconSliders } from "@/components/__legacy__/ui/icons";
const sidebarLinkGroups = [
{
@@ -16,32 +9,27 @@ const sidebarLinkGroups = [
{
text: "Marketplace Management",
href: "/admin/marketplace",
icon: <UsersIcon size={24} />,
icon: <Users className="h-6 w-6" />,
},
{
text: "User Spending",
href: "/admin/spending",
icon: <CurrencyDollarSimpleIcon size={24} />,
},
{
text: "Beta Invites",
href: "/admin/users",
icon: <UserPlusIcon size={24} />,
icon: <DollarSign className="h-6 w-6" />,
},
{
text: "User Impersonation",
href: "/admin/impersonation",
icon: <MagnifyingGlassIcon size={24} />,
icon: <UserSearch className="h-6 w-6" />,
},
{
text: "Execution Analytics",
href: "/admin/execution-analytics",
icon: <FileTextIcon size={24} />,
icon: <FileText className="h-6 w-6" />,
},
{
text: "Admin User Management",
href: "/admin/settings",
icon: <SlidersHorizontalIcon size={24} />,
icon: <IconSliders className="h-6 w-6" />,
},
],
},


@@ -1,80 +0,0 @@
"use client";
import { Card } from "@/components/atoms/Card/Card";
import { BulkInviteForm } from "../BulkInviteForm/BulkInviteForm";
import { InviteUserForm } from "../InviteUserForm/InviteUserForm";
import { InvitedUsersTable } from "../InvitedUsersTable/InvitedUsersTable";
import { useAdminUsersPage } from "../../useAdminUsersPage";
export function AdminUsersPage() {
const {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers,
isLoadingInvitedUsers,
isRefreshingInvitedUsers,
isCreatingInvite,
isBulkInviting,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
} = useAdminUsersPage();
return (
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
<div className="flex flex-col gap-2">
<h1 className="text-3xl font-bold text-zinc-900">Beta Invites</h1>
<p className="max-w-3xl text-sm text-zinc-600">
Pre-provision beta users before they sign up. Invites store the
platform-side record, run Tally understanding extraction, and activate
the real account on the user&apos;s first authenticated request.
</p>
</div>
<div className="grid gap-6 xl:grid-cols-[24rem,1fr]">
<div className="flex flex-col gap-6">
<Card className="border border-zinc-200 shadow-sm">
<InviteUserForm
email={email}
name={name}
isSubmitting={isCreatingInvite}
onEmailChange={setEmail}
onNameChange={setName}
onSubmit={handleCreateInvite}
/>
</Card>
<Card className="border border-zinc-200 shadow-sm">
<BulkInviteForm
selectedFile={bulkInviteFile}
inputKey={bulkInviteInputKey}
isSubmitting={isBulkInviting}
lastResult={lastBulkInviteResult}
onFileChange={handleBulkInviteFileChange}
onSubmit={handleBulkInviteSubmit}
/>
</Card>
</div>
<Card className="border border-zinc-200 shadow-sm">
<InvitedUsersTable
invitedUsers={invitedUsers}
isLoading={isLoadingInvitedUsers}
isRefreshing={isRefreshingInvitedUsers}
pendingInviteAction={pendingInviteAction}
onRetryTally={handleRetryTally}
onRevoke={handleRevoke}
/>
</Card>
</div>
</div>
);
}


@@ -1,135 +0,0 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import type { FormEvent } from "react";
interface Props {
selectedFile: File | null;
inputKey: number;
isSubmitting: boolean;
lastResult: BulkInvitedUsersResponse | null;
onFileChange: (file: File | null) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
function getStatusVariant(status: "CREATED" | "SKIPPED" | "ERROR") {
if (status === "CREATED") {
return "success";
}
if (status === "ERROR") {
return "error";
}
return "info";
}
export function BulkInviteForm({
selectedFile,
inputKey,
isSubmitting,
lastResult,
onFileChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Bulk invite</h2>
<p className="text-sm text-zinc-600">
Upload a <span className="font-medium text-zinc-800">.txt</span> file
with one email per line, or a{" "}
<span className="font-medium text-zinc-800">.csv</span> with
<span className="font-medium text-zinc-800"> email</span> and optional
<span className="font-medium text-zinc-800"> name</span> columns.
</p>
</div>
<label
htmlFor="bulk-invite-file-input"
className="flex cursor-pointer flex-col gap-2 rounded-2xl border border-dashed border-zinc-300 bg-zinc-50 px-4 py-5 text-sm text-zinc-600 transition-colors focus-within:ring-2 focus-within:ring-zinc-500 focus-within:ring-offset-2 hover:border-zinc-400 hover:bg-zinc-100"
>
<span className="font-medium text-zinc-900">
{selectedFile ? selectedFile.name : "Choose invite file"}
</span>
<span>Maximum 500 rows, UTF-8 encoded.</span>
<input
id="bulk-invite-file-input"
key={inputKey}
type="file"
accept=".txt,.csv,text/plain,text/csv"
disabled={isSubmitting}
className="sr-only"
onChange={(event) =>
onFileChange(event.target.files?.item(0) ?? null)
}
/>
</label>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!selectedFile}
className="w-full"
>
{isSubmitting ? "Uploading invites..." : "Upload invite file"}
</Button>
{lastResult ? (
<div className="flex flex-col gap-3 rounded-2xl border border-zinc-200 bg-zinc-50 p-4">
<div className="grid grid-cols-3 gap-2 text-center">
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.created_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Created
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.skipped_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Skipped
</div>
</div>
<div className="rounded-xl bg-white px-3 py-2">
<div className="text-lg font-semibold text-zinc-900">
{lastResult.error_count}
</div>
<div className="text-xs uppercase tracking-[0.16em] text-zinc-500">
Errors
</div>
</div>
</div>
<div className="max-h-64 overflow-y-auto rounded-xl border border-zinc-200 bg-white">
<div className="flex flex-col divide-y divide-zinc-100">
{lastResult.results.map((row) => (
<div
key={`${row.row_number}-${row.email ?? row.message}`}
className="flex items-start gap-3 px-3 py-3"
>
<Badge variant={getStatusVariant(row.status)} size="small">
{row.status}
</Badge>
<div className="flex min-w-0 flex-1 flex-col gap-1">
<span className="text-sm font-medium text-zinc-900">
Row {row.row_number}
{row.email ? ` · ${row.email}` : ""}
</span>
<span className="text-xs text-zinc-500">{row.message}</span>
</div>
</div>
))}
</div>
</div>
</div>
) : null}
</form>
);
}


@@ -1,66 +0,0 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import type { FormEvent } from "react";
interface Props {
email: string;
name: string;
isSubmitting: boolean;
onEmailChange: (value: string) => void;
onNameChange: (value: string) => void;
onSubmit: (event: FormEvent<HTMLFormElement>) => void;
}
export function InviteUserForm({
email,
name,
isSubmitting,
onEmailChange,
onNameChange,
onSubmit,
}: Props) {
return (
<form className="flex flex-col gap-4" onSubmit={onSubmit}>
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Create invite</h2>
<p className="text-sm text-zinc-600">
The invite is stored immediately, then Tally pre-seeding starts in the
background.
</p>
</div>
<Input
id="invite-email"
label="Email"
type="email"
value={email}
placeholder="jane@example.com"
autoComplete="email"
disabled={isSubmitting}
onChange={(event) => onEmailChange(event.target.value)}
/>
<Input
id="invite-name"
label="Name"
type="text"
value={name}
placeholder="Jane Doe"
disabled={isSubmitting}
onChange={(event) => onNameChange(event.target.value)}
/>
<Button
type="submit"
variant="primary"
loading={isSubmitting}
disabled={!email.trim()}
className="w-full"
>
{isSubmitting ? "Creating invite..." : "Create invite"}
</Button>
</form>
);
}


@@ -1,209 +0,0 @@
"use client";
import type { InvitedUserResponse } from "@/app/api/__generated__/models/invitedUserResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
interface Props {
invitedUsers: InvitedUserResponse[];
isLoading: boolean;
isRefreshing: boolean;
pendingInviteAction: string | null;
onRetryTally: (invitedUserId: string) => void;
onRevoke: (invitedUserId: string) => void;
}
function getInviteBadgeVariant(status: InvitedUserResponse["status"]) {
if (status === "CLAIMED") {
return "success";
}
if (status === "REVOKED") {
return "error";
}
return "info";
}
function getTallyBadgeVariant(status: InvitedUserResponse["tally_status"]) {
if (status === "READY") {
return "success";
}
if (status === "FAILED") {
return "error";
}
return "info";
}
function formatDate(value: Date | undefined) {
if (!value) {
return "-";
}
return value.toLocaleString();
}
function getTallySummary(invitedUser: InvitedUserResponse) {
if (invitedUser.tally_status === "FAILED" && invitedUser.tally_error) {
return invitedUser.tally_error;
}
if (invitedUser.tally_status === "READY" && invitedUser.tally_understanding) {
return "Stored and ready for activation";
}
if (invitedUser.tally_status === "READY") {
return "No matching Tally submission found";
}
if (invitedUser.tally_status === "RUNNING") {
return "Extraction in progress";
}
return "Waiting to run";
}
function isActionPending(
pendingInviteAction: string | null,
action: "retry" | "revoke",
invitedUserId: string,
) {
return pendingInviteAction === `${action}:${invitedUserId}`;
}
export function InvitedUsersTable({
invitedUsers,
isLoading,
isRefreshing,
pendingInviteAction,
onRetryTally,
onRevoke,
}: Props) {
return (
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">Invited users</h2>
<p className="text-sm text-zinc-600">
Live invite state, claim status, and Tally pre-seeding progress.
</p>
</div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
{isRefreshing ? "Refreshing" : `${invitedUsers.length} total`}
</span>
</div>
<div className="overflow-hidden rounded-2xl border border-zinc-200">
<Table>
<TableHeader className="bg-zinc-50">
<TableRow>
<TableHead>Email</TableHead>
<TableHead>Name</TableHead>
<TableHead>Invite</TableHead>
<TableHead>Tally</TableHead>
<TableHead>Claimed User</TableHead>
<TableHead>Created</TableHead>
<TableHead className="text-right">Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{isLoading ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
Loading invited users...
</TableCell>
</TableRow>
) : invitedUsers.length === 0 ? (
<TableRow>
<TableCell
colSpan={7}
className="py-10 text-center text-zinc-500"
>
No invited users yet
</TableCell>
</TableRow>
) : (
invitedUsers.map((invitedUser) => (
<TableRow key={invitedUser.id} className="align-top">
<TableCell className="font-medium text-zinc-900">
{invitedUser.email}
</TableCell>
<TableCell>{invitedUser.name || "-"}</TableCell>
<TableCell>
<Badge variant={getInviteBadgeVariant(invitedUser.status)}>
{invitedUser.status}
</Badge>
</TableCell>
<TableCell>
<div className="flex max-w-xs flex-col gap-2">
<Badge
variant={getTallyBadgeVariant(invitedUser.tally_status)}
>
{invitedUser.tally_status}
</Badge>
<span className="text-xs text-zinc-500">
{getTallySummary(invitedUser)}
</span>
<span className="text-xs text-zinc-400">
{formatDate(invitedUser.tally_computed_at ?? undefined)}
</span>
</div>
</TableCell>
<TableCell className="font-mono text-xs text-zinc-500">
{invitedUser.auth_user_id || "-"}
</TableCell>
<TableCell className="text-sm text-zinc-500">
{formatDate(invitedUser.created_at)}
</TableCell>
<TableCell>
<div className="flex justify-end gap-2">
<Button
variant="outline"
size="small"
disabled={invitedUser.status === "REVOKED"}
loading={isActionPending(
pendingInviteAction,
"retry",
invitedUser.id,
)}
onClick={() => onRetryTally(invitedUser.id)}
>
Retry Tally
</Button>
<Button
variant="secondary"
size="small"
disabled={invitedUser.status !== "INVITED"}
loading={isActionPending(
pendingInviteAction,
"revoke",
invitedUser.id,
)}
onClick={() => onRevoke(invitedUser.id)}
>
Revoke
</Button>
</div>
</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
</div>
</div>
);
}


@@ -1,11 +1,16 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { AdminUsersPage } from "./components/AdminUsersPage/AdminUsersPage";
import React from "react";
function AdminUsers() {
return <AdminUsersPage />;
return (
<div>
<h1>Users Dashboard</h1>
{/* Add your admin-only content here */}
</div>
);
}
export default async function AdminUsersRoute() {
export default async function AdminUsersPage() {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedAdminUsers = await withAdminAccess(AdminUsers);


@@ -1,197 +0,0 @@
"use client";
import type { BulkInvitedUsersResponse } from "@/app/api/__generated__/models/bulkInvitedUsersResponse";
import { okData } from "@/app/api/helpers";
import {
getGetV2ListInvitedUsersQueryKey,
useGetV2ListInvitedUsers,
usePostV2BulkCreateInvitedUsers,
usePostV2CreateInvitedUser,
usePostV2RetryInvitedUserTally,
usePostV2RevokeInvitedUser,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useQueryClient } from "@tanstack/react-query";
import { type FormEvent, useState } from "react";
function getErrorMessage(error: unknown) {
if (error instanceof Error) {
return error.message;
}
return "Something went wrong";
}
export function useAdminUsersPage() {
const queryClient = useQueryClient();
const { toast } = useToast();
const [email, setEmail] = useState("");
const [name, setName] = useState("");
const [bulkInviteFile, setBulkInviteFile] = useState<File | null>(null);
const [bulkInviteInputKey, setBulkInviteInputKey] = useState(0);
const [lastBulkInviteResult, setLastBulkInviteResult] =
useState<BulkInvitedUsersResponse | null>(null);
const [pendingInviteAction, setPendingInviteAction] = useState<string | null>(
null,
);
const invitedUsersQuery = useGetV2ListInvitedUsers(undefined, {
query: {
select: okData,
refetchInterval: 30_000,
},
});
const createInvitedUserMutation = usePostV2CreateInvitedUser({
mutation: {
onSuccess: async () => {
setEmail("");
setName("");
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invited user created",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const bulkCreateInvitedUsersMutation = usePostV2BulkCreateInvitedUsers({
mutation: {
onSuccess: async (response) => {
const result = okData(response) ?? null;
setBulkInviteFile(null);
setBulkInviteInputKey((currentValue) => currentValue + 1);
setLastBulkInviteResult(result);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: result
? `${result.created_count} invites created`
: "Bulk invite upload complete",
variant: "default",
});
},
onError: (error) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const retryInvitedUserTallyMutation = usePostV2RetryInvitedUserTally({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Tally pre-seeding restarted",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const revokeInvitedUserMutation = usePostV2RevokeInvitedUser({
mutation: {
onSuccess: async () => {
setPendingInviteAction(null);
await queryClient.invalidateQueries({
queryKey: getGetV2ListInvitedUsersQueryKey(),
});
toast({
title: "Invite revoked",
variant: "default",
});
},
onError: (error) => {
setPendingInviteAction(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
function handleCreateInvite(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
createInvitedUserMutation.mutate({
data: {
email,
name: name.trim() || null,
},
});
}
function handleRetryTally(invitedUserId: string) {
setPendingInviteAction(`retry:${invitedUserId}`);
retryInvitedUserTallyMutation.mutate({ invitedUserId });
}
function handleBulkInviteFileChange(file: File | null) {
setBulkInviteFile(file);
}
function handleBulkInviteSubmit(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
if (!bulkInviteFile) {
return;
}
bulkCreateInvitedUsersMutation.mutate({
data: {
file: bulkInviteFile,
},
});
}
function handleRevoke(invitedUserId: string) {
setPendingInviteAction(`revoke:${invitedUserId}`);
revokeInvitedUserMutation.mutate({ invitedUserId });
}
return {
email,
name,
bulkInviteFile,
bulkInviteInputKey,
lastBulkInviteResult,
invitedUsers: invitedUsersQuery.data?.invited_users ?? [],
invitedUsersError: invitedUsersQuery.error,
isLoadingInvitedUsers: invitedUsersQuery.isLoading,
isRefreshingInvitedUsers: invitedUsersQuery.isFetching,
isCreatingInvite: createInvitedUserMutation.isPending,
isBulkInviting: bulkCreateInvitedUsersMutation.isPending,
pendingInviteAction,
setEmail,
setName,
handleBulkInviteFileChange,
handleBulkInviteSubmit,
handleCreateInvite,
handleRetryTally,
handleRevoke,
};
}

@@ -1,9 +1,7 @@
"use client";
import { useGetV2GetSuggestedPrompts } from "@/app/api/__generated__/endpoints/chat/chat";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { Button } from "@/components/atoms/Button/Button";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { SpinnerGapIcon } from "@phosphor-icons/react";
@@ -35,42 +33,18 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
useGetV2GetSuggestedPrompts({
query: { staleTime: Infinity },
});
const customPrompts =
suggestedPromptsResponse?.status === 200
? suggestedPromptsResponse.data.prompts
: undefined;
const quickActions = getQuickActions(customPrompts);
const quickActions = getQuickActions();
const [loadingAction, setLoadingAction] = useState<string | null>(null);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
// Use matchMedia instead of resize event — fires only when crossing
// the 500px and 1081px breakpoints defined in getInputPlaceholder(),
// rather than dozens of times per second during a window drag.
useEffect(() => {
function update() {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
}
const mq500 = window.matchMedia("(min-width: 500px)");
const mq1081 = window.matchMedia("(min-width: 1081px)");
update();
mq500.addEventListener("change", update);
mq1081.addEventListener("change", update);
return () => {
mq500.removeEventListener("change", update);
mq1081.removeEventListener("change", update);
};
}, []);
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
}, [window.innerWidth]);
async function handleQuickActionClick(action: string) {
if (isCreatingSession || loadingAction) return;
if (isCreatingSession || loadingAction !== null) return;
setLoadingAction(action);
try {
await onSend(action);
@@ -116,32 +90,28 @@ export function EmptySession({
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{isLoadingPrompts
? Array.from({ length: 3 }, (_, i) => (
<Skeleton key={i} className="h-10 w-64 shrink-0 rounded-full" />
))
: quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => void handleQuickActionClick(action)}
disabled={isCreatingSession || loadingAction !== null}
aria-busy={loadingAction === action}
leftIcon={
loadingAction === action ? (
<SpinnerGapIcon
className="h-4 w-4 animate-spin"
weight="bold"
/>
) : null
}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
{quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => void handleQuickActionClick(action)}
disabled={isCreatingSession || loadingAction !== null}
aria-busy={loadingAction === action}
leftIcon={
loadingAction === action ? (
<SpinnerGapIcon
className="h-4 w-4 animate-spin"
weight="bold"
/>
) : null
}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
</div>
</motion.div>
</div>

@@ -12,17 +12,12 @@ export function getInputPlaceholder(width?: number) {
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}
const DEFAULT_QUICK_ACTIONS = [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
export function getQuickActions(customPrompts?: string[]) {
if (customPrompts && customPrompts.length > 0) {
return customPrompts;
}
return DEFAULT_QUICK_ACTIONS;
export function getQuickActions() {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
}
export function getGreetingName(user?: User | null) {

@@ -15,46 +15,11 @@ import { useCopilotUIStore } from "./store";
import { useChatSession } from "./useChatSession";
import { useCopilotNotifications } from "./useCopilotNotifications";
import { useCopilotStream } from "./useCopilotStream";
import { useWorkflowImportAutoSubmit } from "./useWorkflowImportAutoSubmit";
const TITLE_POLL_INTERVAL_MS = 2_000;
const TITLE_POLL_MAX_ATTEMPTS = 5;
/**
* Extract a prompt from the URL hash fragment.
* Supports: /copilot#prompt=URL-encoded-text
* Optionally auto-submits if ?autosubmit=true is in the query string.
* Returns null if no prompt is present.
*/
function extractPromptFromUrl(): {
prompt: string;
autosubmit: boolean;
} | null {
if (typeof window === "undefined") return null;
const hash = window.location.hash;
if (!hash) return null;
const hashParams = new URLSearchParams(hash.slice(1));
const prompt = hashParams.get("prompt");
if (!prompt || !prompt.trim()) return null;
const searchParams = new URLSearchParams(window.location.search);
const autosubmit = searchParams.get("autosubmit") === "true";
// Clean up hash + autosubmit param only (preserve other query params)
const cleanURL = new URL(window.location.href);
cleanURL.hash = "";
cleanURL.searchParams.delete("autosubmit");
window.history.replaceState(
null,
"",
`${cleanURL.pathname}${cleanURL.search}`,
);
return { prompt: prompt.trim(), autosubmit };
}
interface UploadedFile {
file_id: string;
name: string;
@@ -130,16 +95,23 @@ export function useCopilotPage() {
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
const pendingFilesRef = useRef<File[]>([]);
// Pre-built file parts from workflow import (already uploaded, skip re-upload)
const pendingFilePartsRef = useRef<FileUIPart[]>([]);
// --- Send pending message after session creation ---
useEffect(() => {
if (!sessionId || pendingMessage === null) return;
const msg = pendingMessage;
const files = pendingFilesRef.current;
const prebuiltParts = pendingFilePartsRef.current;
setPendingMessage(null);
pendingFilesRef.current = [];
pendingFilePartsRef.current = [];
if (files.length > 0) {
if (prebuiltParts.length > 0) {
// File already uploaded (e.g. workflow import) — send directly
sendMessage({ text: msg, files: prebuiltParts });
} else if (files.length > 0) {
setIsUploadingFiles(true);
void uploadFiles(files, sessionId)
.then((uploaded) => {
@@ -164,26 +136,11 @@ export function useCopilotPage() {
}, [sessionId, pendingMessage, sendMessage]);
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
const { setInitialPrompt } = useCopilotUIStore();
const hasProcessedUrlPrompt = useRef(false);
useEffect(() => {
if (hasProcessedUrlPrompt.current) return;
const urlPrompt = extractPromptFromUrl();
if (!urlPrompt) return;
hasProcessedUrlPrompt.current = true;
if (urlPrompt.autosubmit) {
setPendingMessage(urlPrompt.prompt);
void createSession().catch(() => {
setPendingMessage(null);
setInitialPrompt(urlPrompt.prompt);
});
} else {
setInitialPrompt(urlPrompt.prompt);
}
}, [createSession, setInitialPrompt]);
useWorkflowImportAutoSubmit({
createSession,
setPendingMessage,
pendingFilePartsRef,
});
async function uploadFiles(
files: File[],

@@ -0,0 +1,122 @@
import type { FileUIPart } from "ai";
import { useEffect, useRef } from "react";
import { useCopilotUIStore } from "./store";
/**
* Extract a prompt from the URL hash fragment.
* Supports: /copilot#prompt=URL-encoded-text
* Optionally auto-submits if ?autosubmit=true is in the query string.
* Returns null if no prompt is present.
*/
function extractPromptFromUrl(): {
prompt: string;
autosubmit: boolean;
filePart?: FileUIPart;
} | null {
if (typeof window === "undefined") return null;
const searchParams = new URLSearchParams(window.location.search);
const autosubmit = searchParams.get("autosubmit") === "true";
// Check sessionStorage first (used by workflow import for large prompts)
const storedPrompt = sessionStorage.getItem("importWorkflowPrompt");
if (storedPrompt) {
sessionStorage.removeItem("importWorkflowPrompt");
// Check for a pre-uploaded workflow file attached to this import
let filePart: FileUIPart | undefined;
const storedFile = sessionStorage.getItem("importWorkflowFile");
if (storedFile) {
sessionStorage.removeItem("importWorkflowFile");
try {
const { fileId, fileName, mimeType } = JSON.parse(storedFile);
// Validate fileId is a UUID to prevent path traversal
const UUID_RE =
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
if (typeof fileId === "string" && UUID_RE.test(fileId)) {
filePart = {
type: "file",
mediaType: mimeType ?? "application/json",
filename: fileName ?? "workflow.json",
url: `/api/proxy/api/workspace/files/${fileId}/download`,
};
}
} catch {
// ignore malformed stored data
}
}
// Clean up query params
const cleanURL = new URL(window.location.href);
cleanURL.searchParams.delete("autosubmit");
cleanURL.searchParams.delete("source");
window.history.replaceState(
null,
"",
`${cleanURL.pathname}${cleanURL.search}`,
);
return { prompt: storedPrompt.trim(), autosubmit, filePart };
}
// Fall back to URL hash (e.g. /copilot#prompt=...)
const hash = window.location.hash;
if (!hash) return null;
const hashParams = new URLSearchParams(hash.slice(1));
const prompt = hashParams.get("prompt");
if (!prompt || !prompt.trim()) return null;
// Clean up hash + autosubmit param only (preserve other query params)
const cleanURL = new URL(window.location.href);
cleanURL.hash = "";
cleanURL.searchParams.delete("autosubmit");
window.history.replaceState(
null,
"",
`${cleanURL.pathname}${cleanURL.search}`,
);
return { prompt: prompt.trim(), autosubmit };
}
/**
* Hook that checks for workflow import data in sessionStorage / URL on mount,
* and auto-submits a new CoPilot session when `autosubmit=true`.
*
* Extracted from useCopilotPage to keep that hook focused on page-level concerns.
*/
export function useWorkflowImportAutoSubmit({
createSession,
setPendingMessage,
pendingFilePartsRef,
}: {
createSession: () => Promise<string | undefined>;
setPendingMessage: (msg: string | null) => void;
pendingFilePartsRef: React.MutableRefObject<FileUIPart[]>;
}) {
const { setInitialPrompt } = useCopilotUIStore();
const hasProcessedUrlPrompt = useRef(false);
useEffect(() => {
if (hasProcessedUrlPrompt.current) return;
const urlPrompt = extractPromptFromUrl();
if (!urlPrompt) return;
hasProcessedUrlPrompt.current = true;
if (urlPrompt.autosubmit) {
if (urlPrompt.filePart) {
pendingFilePartsRef.current = [urlPrompt.filePart];
}
setPendingMessage(urlPrompt.prompt);
void createSession().catch(() => {
setPendingMessage(null);
setInitialPrompt(urlPrompt.prompt);
});
} else {
setInitialPrompt(urlPrompt.prompt);
}
}, [createSession, setInitialPrompt, setPendingMessage, pendingFilePartsRef]);
}
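
Review note: the two parsing steps in `extractPromptFromUrl` above (the hash prompt and the stored workflow file) can be exercised without a browser if they are written as pure functions. The sketch below is illustrative only and is not part of this diff. `parsePromptFromHref`, `parseStoredWorkflowFile`, and the `cleanedPath` field are hypothetical names; the `prompt` and `autosubmit` parameters, the UUID check, and the download URL shape are copied from the hunk.

```typescript
// Hypothetical pure helpers mirroring the parsing logic in the hunk above.
// They take their inputs as arguments instead of reading window/sessionStorage.

interface ParsedPrompt {
  prompt: string;
  autosubmit: boolean;
  cleanedPath: string; // path + query with the hash and `autosubmit` removed
}

function parsePromptFromHref(href: string): ParsedPrompt | null {
  const url = new URL(href);
  if (!url.hash) return null;

  // The hash is treated as a query string: "#prompt=Hello%20world"
  const hashParams = new URLSearchParams(url.hash.slice(1));
  const prompt = hashParams.get("prompt");
  if (!prompt || !prompt.trim()) return null;

  const autosubmit = url.searchParams.get("autosubmit") === "true";

  // Drop the hash and `autosubmit` only; other query params are preserved.
  url.hash = "";
  url.searchParams.delete("autosubmit");

  return {
    prompt: prompt.trim(),
    autosubmit,
    cleanedPath: `${url.pathname}${url.search}`,
  };
}

const UUID_RE =
  /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;

interface StoredFilePart {
  mediaType: string;
  filename: string;
  url: string;
}

function parseStoredWorkflowFile(raw: string): StoredFilePart | null {
  try {
    const { fileId, fileName, mimeType } = JSON.parse(raw);
    // Only a UUID-shaped id is interpolated into the download path.
    if (typeof fileId !== "string" || !UUID_RE.test(fileId)) return null;
    return {
      mediaType: mimeType ?? "application/json",
      filename: fileName ?? "workflow.json",
      url: `/api/proxy/api/workspace/files/${fileId}/download`,
    };
  } catch {
    // Malformed JSON (or a non-object payload) is ignored, as in the hunk.
    return null;
  }
}
```

Under these assumptions, `parsePromptFromHref("https://example.com/copilot?autosubmit=true&tab=x#prompt=Hello%20world")` yields the prompt `Hello world` with `autosubmit` set and a cleaned path of `/copilot?tab=x`, while a stored `fileId` such as `../../etc/passwd` is rejected before it reaches the URL template.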

@@ -169,7 +169,7 @@ function renderMarkdown(
[remarkMath, { singleDollarTextMath: false }], // Math support for LaTeX
]}
rehypePlugins={[
rehypeKatex, // Render math with KaTeX
[rehypeKatex, { strict: false }], // Render math with KaTeX
rehypeHighlight, // Syntax highlighting for code blocks
rehypeSlug, // Add IDs to headings
[rehypeAutolinkHeadings, { behavior: "wrap" }], // Make headings clickable

@@ -8,33 +8,39 @@ import { Text } from "@/components/atoms/Text/Text";
import { useJumpBackIn } from "./useJumpBackIn";
export function JumpBackIn() {
const { agent, isLoading } = useJumpBackIn();
const { execution, isLoading } = useJumpBackIn();
if (isLoading || !agent) {
if (isLoading || !execution) {
return null;
}
const href = execution.libraryAgentId
? `/library/agents/${execution.libraryAgentId}?activeTab=runs&activeItem=${execution.id}`
: "#";
return (
<div className="flex items-center justify-between rounded-large border border-zinc-200 bg-gradient-to-r from-zinc-50 to-white px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
Continue where you left off
</Text>
<Text variant="body-medium" className="text-zinc-900">
{agent.name}
</Text>
<div className="rounded-large bg-gradient-to-r from-zinc-200 via-zinc-200/60 to-indigo-200/50 p-[1px]">
<div className="flex items-center justify-between rounded-large bg-[#F6F7F8] px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
{execution.statusLabel} · {execution.duration}
</Text>
<Text variant="body-medium" className="text-zinc-900">
{execution.agentName}
</Text>
</div>
</div>
<NextLink href={href}>
<Button variant="secondary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
<NextLink href={`/library/agents/${agent.id}`}>
<Button variant="primary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
);
}

@@ -1,28 +1,82 @@
"use client";
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { okData } from "@/app/api/helpers";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useMemo } from "react";
function isActive(status: AgentExecutionStatus) {
return (
status === AgentExecutionStatus.RUNNING ||
status === AgentExecutionStatus.QUEUED ||
status === AgentExecutionStatus.REVIEW
);
}
function formatDuration(startedAt: Date | string | null | undefined): string {
if (!startedAt) return "";
const start = new Date(startedAt);
if (isNaN(start.getTime())) return "";
const ms = Date.now() - start.getTime();
if (ms < 0) return "";
const sec = Math.floor(ms / 1000);
if (sec < 5) return "a few seconds";
if (sec < 60) return `${sec}s`;
const min = Math.floor(sec / 60);
if (min < 60) return `${min}m ${sec % 60}s`;
const hr = Math.floor(min / 60);
return `${hr}h ${min % 60}m`;
}
function getStatusLabel(status: AgentExecutionStatus) {
if (status === AgentExecutionStatus.RUNNING) return "Running";
if (status === AgentExecutionStatus.QUEUED) return "Queued";
if (status === AgentExecutionStatus.REVIEW) return "Awaiting approval";
return "";
}
export function useJumpBackIn() {
const { data, isLoading } = useGetV2ListLibraryAgents(
{
page: 1,
page_size: 1,
sort_by: "updatedAt",
},
{
const { data: executions, isLoading: executionsLoading } =
useGetV1ListAllExecutions({
query: { select: okData },
},
);
});
// The API doesn't include execution data by default (include_executions is
// internal to the backend), so recent_executions is always empty here.
// We use the most recently updated agent as the "jump back in" candidate
// instead — updatedAt is the best available proxy for recent activity.
const agent = data?.agents[0] ?? null;
const { agentInfoMap, isRefreshing: agentsLoading } = useLibraryAgents();
const activeExecution = useMemo(() => {
if (!executions) return null;
const active = executions
.filter((e) => isActive(e.status))
.sort((a, b) => {
const aTime = a.started_at ? new Date(a.started_at).getTime() : 0;
const bTime = b.started_at ? new Date(b.started_at).getTime() : 0;
return bTime - aTime;
});
return active[0] ?? null;
}, [executions]);
const enriched = useMemo(() => {
if (!activeExecution) return null;
const info = agentInfoMap.get(activeExecution.graph_id);
return {
id: activeExecution.id,
agentName: info?.name ?? "Unknown Agent",
libraryAgentId: info?.library_agent_id,
status: activeExecution.status,
statusLabel: getStatusLabel(activeExecution.status),
duration: formatDuration(activeExecution.started_at),
};
}, [activeExecution, agentInfoMap]);
return {
agent,
isLoading,
execution: enriched,
isLoading: executionsLoading || agentsLoading,
};
}
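
Review note: the thresholds in `formatDuration` above are easier to check with a fixed clock. The sketch below repeats the same branching but takes the current time as a parameter; `formatElapsed` and `nowMs` are hypothetical names introduced only for this illustration, and the diff itself calls `Date.now()` directly.

```typescript
// Hypothetical variant of formatDuration with an injected clock (nowMs),
// so each branch can be checked deterministically.
function formatElapsed(
  startedAt: Date | string | null | undefined,
  nowMs: number,
): string {
  if (!startedAt) return "";
  const start = new Date(startedAt);
  if (isNaN(start.getTime())) return ""; // unparseable timestamp
  const ms = nowMs - start.getTime();
  if (ms < 0) return ""; // start time in the future
  const sec = Math.floor(ms / 1000);
  if (sec < 5) return "a few seconds";
  if (sec < 60) return `${sec}s`;
  const min = Math.floor(sec / 60);
  if (min < 60) return `${min}m ${sec % 60}s`;
  const hr = Math.floor(min / 60);
  return `${hr}h ${min % 60}m`;
}

const startedAt = "2026-04-07T12:00:00Z";
const startMs = Date.parse(startedAt);

console.log(formatElapsed(startedAt, startMs + 3_000)); // "a few seconds"
console.log(formatElapsed(startedAt, startMs + 42_000)); // "42s"
console.log(formatElapsed(startedAt, startMs + 125_000)); // "2m 5s"
console.log(formatElapsed(startedAt, startMs + 3_725_000)); // "1h 2m"
```

Note that once an execution passes the one-hour mark the seconds are dropped, so 3725 seconds renders as `1h 2m` rather than `1h 2m 5s`.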

@@ -1,5 +1,5 @@
import LibraryImportDialog from "../LibraryImportDialog/LibraryImportDialog";
import { LibrarySearchBar } from "../LibrarySearchBar/LibrarySearchBar";
import LibraryUploadAgentDialog from "../LibraryUploadAgentDialog/LibraryUploadAgentDialog";
interface Props {
setSearchTerm: (value: string) => void;
@@ -10,13 +10,13 @@ export function LibraryActionHeader({ setSearchTerm }: Props) {
<>
<div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
<LibrarySearchBar setSearchTerm={setSearchTerm} />
<LibraryUploadAgentDialog />
<LibraryImportDialog />
</div>
{/* Mobile and tablet */}
<div className="flex flex-col gap-4 p-4 pt-[52px] md:hidden">
<div className="flex w-full justify-between">
<LibraryUploadAgentDialog />
<div className="flex w-full justify-between gap-2">
<LibraryImportDialog />
</div>
<div className="flex items-center justify-center">

Some files were not shown because too many files have changed in this diff