Compare commits

..

82 Commits

Author SHA1 Message Date
Bentlybro
e01526cf52 fix: address latest CodeRabbit review comments
- Fix TruncatedLogger calls to use f-strings instead of %s format args (4 calls)
- Fix get_parallel_tool_calls_param return type: NotGiven → Omit
- Add comment clarifying LlmModelMigration is system-level data
- Add pagination input validation to prevent division by zero
2026-02-13 16:43:04 +00:00
Bentlybro
1704812f50 fix: address CodeRabbit review comments
- Fix cache initialization inconsistency in BlockSchema.__pydantic_init_subclass__
  (use None instead of {} to match clear_schema_cache behavior)
- Change logger.error to logger.debug in stagehand blocks (not an error condition)
2026-02-13 16:37:54 +00:00
Bentlybro
29f95e5b61 fix(builder): lowercase query for LLM model matching
The model slugs are lowercased in _get_llm_models() but the query wasn't,
causing case-sensitive matching failures (e.g., 'GPT-4' wouldn't match 'gpt 4').
2026-02-13 15:29:49 +00:00
Bentlybro
266526f08c fix(ws): close registry pubsub connection on shutdown
Track the registry_pubsub connection and close it in the finally block
to prevent Redis connection leaks on WebSocket server shutdown.
2026-02-13 15:12:55 +00:00
Bentlybro
26490e32d8 fix(schema): add composite index for LlmModelMigration active queries
Adds @@index([sourceModelSlug, isReverted]) to match the migration SQL.
This prevents Prisma migration conflicts and optimizes queries for
non-reverted migrations by source model slug.
2026-02-13 15:11:48 +00:00
Bentlybro
d6bf54281b fix(builder): normalize query hyphens for LLM model search matching
Apply same hyphen-to-space normalization to the query that's applied to
model slugs. This ensures 'gpt-4' matches 'gpt 4o' in search.
2026-02-13 15:11:01 +00:00
Bentlybro
a7835056c9 fix(llm): key migration cost overrides by targetModelSlug
The override dict should be keyed by the TARGET model slug (where nodes
migrated TO), not the source model slug. This ensures the custom cost
is applied when building costs for the model that nodes are actually using.
2026-02-13 15:10:09 +00:00
Bentlybro
cf3390d192 fix(llm): move count and validation inside transaction to prevent TOCTOU
Moves the node count query and replacement model validation inside the
transaction to prevent race conditions where nodes could be created
between the count and the actual deletion.

All mutation logic is now atomic within a single transaction.
2026-02-13 15:09:26 +00:00
Bentlybro
d8007f74e9 fix(llm): handle provider-prefixed slugs in o-series detection
The regex now matches 'o' followed by digit at start OR after '/' separator.
This fixes detection for slugs like 'openai/o1-mini' from OpenRouter.
2026-02-13 15:08:15 +00:00
Bentlybro
4d341c55c5 Update openapi.json 2026-02-13 15:03:54 +00:00
Bentlybro
01ef7e1925 refactor(llm): extract model resolution logic into resolve_model_for_call()
- Add ModelUnavailableError for clearer error handling (extends ValueError for backward compat)
- Add ResolvedModel dataclass to hold resolved model metadata
- Extract all model resolution logic (disabled check, fallback, registry refresh) into resolve_model_for_call()
- Simplify llm_call() to use the new function (77 lines → 7 lines)

This improves maintainability by separating concerns:
- resolve_model_for_call() handles model resolution
- llm_call() handles the actual LLM API call
2026-02-13 14:50:57 +00:00
Bentlybro
5baf1a0f60 Use NotGiven sentinels and add migration timestamps
Update LLM integrations and migration to match SDK and schema changes: switch from anthropic.omit/openai.Omit types to anthropic.NOT_GIVEN/openai.NotGiven in llm.py (and update type hints), stop converting createdAt/revertedAt to ISO strings in DB mapping to preserve datetime types, and add createdAt/updatedAt (NOW()) to LlmModel and LlmModelCost inserts in the migration SQL so new rows populate timestamps.
2026-02-13 14:48:21 +00:00
Bentlybro
9fc5d465da Add BlockSchema cache clearing & fix imports
Make BlockSchema.cached_jsonschema default to None and add clear_schema_cache and clear_all_schema_caches (recursive) so JSON schemas can be invalidated and regenerated. Update modules (rest_api, admin llm_routes, executor llm_registry_init) to import BlockSchema from backend.blocks._base so the new cache-clearing API is used when refreshing LLM costs/discriminator mappings. Also switch cache sentinel from {} to None to avoid truthiness preventing regeneration.
2026-02-13 11:56:47 +00:00
Bentlybro
c797f4e1f2 Update ModelsTable.tsx 2026-02-13 11:53:07 +00:00
Bentlybro
05033610bb Reorder and remove unused imports
Reorders the llm_registry import in backend/blocks/llm.py to group data imports together and updates import ordering. Removes unused imports (update_schema_with_llm_registry, NodeExecutionStats, ProviderName) from backend/data/block.py to clean up dead imports and simplify the module surface.
2026-02-13 11:45:55 +00:00
Bentlybro
76f3a89be8 Merge branch 'dev' into add-llm-manager-ui 2026-02-13 11:17:04 +00:00
Bentlybro
df7bb57c83 Update Table.tsx 2026-02-12 11:14:22 +00:00
Bentlybro
b11d46d246 Make LLM cost refresh async and support overrides
Convert refresh_llm_costs to async and update all callers to await it. Implement async _build_llm_costs_from_registry which queries prisma LlmModelMigration for active migrations with customCreditCost and applies per-model pricing overrides when present (with a safe try/except). Add two SQL migrations: a composite index on LlmModelMigration to optimize override queries and a sync migration to add/remove/update LLM models and their costs. This ensures billing uses migration-provided custom pricing and that registry refreshes correctly await cost recalculation.
2026-02-12 11:11:01 +00:00
Bentlybro
8e6bc5eb48 Update route examples and compress_context call
Update doc examples in admin/llm_routes.py to use the new /api/llm/admin/... path. Change compress_context invocation in blocks/llm.py to pass client=None (truncation-only, no LLM summarization) instead of using the lossy_ok parameter.
2026-02-12 09:07:24 +00:00
Bentlybro
8b2b0c853a Update openapi.json 2026-02-11 14:05:32 +00:00
Bentlybro
ffb86cced4 Merge remote-tracking branch 'origin/dev' into add-llm-manager-ui 2026-02-11 13:45:56 +00:00
Bentlybro
fea46a6d28 Use LlmModel and simplify cache clearing
Refactor LLM handling and cache logic: instantiate and pass a LlmModel instance to generate_model_label (rename model_enum -> model) to ensure consistent enum usage when building labels. Remove hasattr guards and directly clear the v2 builder caches during runtime state refresh so cached providers and search results are always attempted to be cleared. Update the AIConditionBlock test fixture to use LlmModel.default() instead of a hardcoded gpt-4o string. These changes simplify the code and standardize LlmModel usage.
2026-02-10 15:32:36 +00:00
Nicholas Tindle
f2f779e54f Merge branch 'dev' into add-llm-manager-ui 2026-01-27 10:39:47 -06:00
Bentlybro
dda9a9b010 Update llm.py 2026-01-23 15:07:55 +00:00
Bentlybro
c1d3604682 Improve LlmModelMeta slug generation logic
Slug generation now checks for exact matches in the registry before applying the letter-digit hyphen transformation. This ensures that model names like 'o1' are preserved as-is if present in the registry, improving compatibility with dynamic model slugs.
2026-01-23 14:59:49 +00:00
Bentlybro
dfbfbdf696 Add pagination and lazy loading to models table
Implemented client-side pagination for the LLM models table in the admin UI, including a 'Load More' button and loading state. The backend now only returns enabled models for selection. This improves performance and usability when managing large numbers of models.
2026-01-23 12:12:32 +00:00
Bentlybro
994ebc2cf8 Merge branch 'dev' into add-llm-manager-ui 2026-01-22 14:38:24 +00:00
Bentlybro
2245d115d3 Refactor form field extraction and validation utilities
Introduced utility functions for extracting and validating required fields from FormData, reducing code duplication and improving error handling across LLM provider, model, and creator actions. Updated all relevant actions to use these new utilities for consistent validation.
2026-01-22 14:07:59 +00:00
Bentlybro
5238b1b71c Add input validation to LLM provider/model actions
Improves robustness by validating and sanitizing form data in deleteLlmProviderAction and createLlmModelAction. Ensures required fields are present and context window and credit cost are valid numbers before proceeding.
2026-01-22 13:51:54 +00:00
Bentlybro
4fb86b2738 Update actions.ts 2026-01-22 13:44:46 +00:00
Bentlybro
e10128e9f0 Improve LLM provider form data handling
Parse 'default_credential_id' and 'default_credential_type' from form data instead of using static values. Update boolean field parsing to use getAll and check for 'on' to better support multiple checkbox inputs.
2026-01-22 13:41:37 +00:00
Bentlybro
b205d5863e format 2026-01-22 13:13:46 +00:00
Bentlybro
6da2dee62f Add edit and delete functionality for LLM providers
Introduces backend API and frontend UI for editing and deleting LLM providers. Providers can only be deleted if they have no associated models. Includes new modals for editing and deleting providers, updates provider list to show model count and actions, and adds corresponding actions and API integration.
2026-01-22 13:08:29 +00:00
Bentlybro
324ebc1e06 Fix LLM model creation, DB JSON handling, and migration logic
Corrects handling of JSON fields in the backend by wrapping metadata and capabilities in prisma.Json, and updates model/creator relationship to use Prisma connect syntax. Updates LlmModelMigration timestamps to use datetime objects. Adjusts SQL migrations to avoid duplicate table/constraint creation and adds conditional foreign key logic. Fixes frontend LLM model form to properly handle is_enabled checkbox state.
2026-01-22 12:37:31 +00:00
Bentlybro
ce2ebee838 Refactor LlmModel priceTier and add creator support
Removes the priceTier field from the LlmModel seed migration and moves price tier assignments to a dedicated migration. Adds new columns to LlmModel for creatorId and isRecommended, creates the LlmModelCreator table, and updates priceTier values for existing models to support enhanced LLM Picker UI functionality.
2026-01-22 12:04:13 +00:00
Bentlybro
0597573b6c Merge branch 'dev' into add-llm-manager-ui 2026-01-22 11:52:43 +00:00
Bentlybro
9496b33a1c Add price tier to LLM model metadata and registry
Introduces a 'priceTier' attribute (1=cheapest, 2=medium, 3=expensive) to LlmModel in the database schema, model metadata, and registry logic. Updates migrations and seed data to support price tier for LLM models, enabling cost-based filtering and selection in the LLM Picker UI.
2026-01-22 11:52:37 +00:00
Bentlybro
8e3aabd558 Use effective model for parallel tool calls param
Replaces usage of llm_model with effective_model when resolving parallel tool calls parameters. This ensures model-specific parameter resolution uses the actual model in use, including after any fallback.
2026-01-22 11:08:09 +00:00
Bentlybro
fbef81c0c9 Improve LLM model iteration and metadata handling
Added __iter__ to LlmModelMeta for dynamic model iteration and updated metadata retrieval to handle missing registry entries gracefully. Fixed BlockSchema cached_jsonschema initialization and improved discriminator mapping refresh logic. Updated NodeInputs to display beautified string if label is missing.
2026-01-22 10:00:06 +00:00
Bentlybro
226d2ef4a0 Merge branch 'dev' into add-llm-manager-ui 2026-01-21 23:46:07 +00:00
Bentlybro
42f8a26ee1 Allow LLM model deletion without replacement if unused
Updated backend logic and API schema to permit deleting an LLM model without specifying a replacement if no workflow nodes are using it. Adjusted tests to cover both cases (with and without usage), made replacement_model_slug optional in the response model, and updated OpenAPI spec accordingly.
2026-01-21 23:26:52 +00:00
Bentlybro
8d021fe76c Allow LLM model deletion without mandatory migration
Backend and frontend logic updated to allow deletion of LLM models without requiring a replacement if no workflows use the model. The API, UI, and OpenAPI spec now conditionally require a replacement model only when migration is necessary, improving admin workflow and error handling.
2026-01-21 22:23:26 +00:00
Bentlybro
cb10907bf6 Add pagination to LLM model listing endpoints
Introduces pagination support to the LLM model listing APIs in both admin and public routes. Updates the response model to include pagination metadata, modifies database queries to support paging, and adjusts related tests. Also renames model_types.py to model.py for consistency.
2026-01-21 21:00:18 +00:00
Bentlybro
54084fe597 Refactor LLM admin route tests for improved mocking and snapshots
Updated tests to use actual model and response classes from llm_model instead of dicts, ensuring more accurate type usage. Snapshot assertions now serialize responses to JSON strings for compatibility. Cleaned up test_delete_llm_model_missing_replacement to remove unnecessary mocking.
2026-01-19 14:28:33 +00:00
Bentlybro
8f5d851908 Set router prefix in llm_routes_test.py
Added the '/admin/llm' prefix to the included router in the test setup to match the expected route structure.
2026-01-19 14:16:08 +00:00
Bentlybro
358a21c6fc prettier 2026-01-19 14:15:04 +00:00
Bentlybro
336fc43b24 Add unique constraint to LlmModelCost on model, provider, unit
Introduces a unique index on the combination of llmModelId, credentialProvider, and unit in the LlmModelCost table to prevent duplicate cost entries. Updates the seed migration to handle conflicts on this unique key by doing nothing on conflict.
2026-01-19 13:39:20 +00:00
Bentlybro
cfb1613877 Update hidden credential_type input logic in EditModelModal
The hidden input for credential_type now prioritizes cost.credential_type, then provider.default_credential_type, and defaults to 'api_key' if neither is set. This ensures the correct credential type is submitted based on available data.
2026-01-16 14:29:46 +00:00
Bentlybro
386eea741c Rename cost_unit field to unit in LLM model forms
Updated form field and related code references from 'cost_unit' to 'unit' in both create and update LLM model actions, as well as in the EditModelModal component. This change ensures consistency in naming and aligns with expected backend parameters.
2026-01-16 14:19:04 +00:00
Bentlybro
e5c6809d9c Improve LLM model cost unit handling and cache refresh
Adds explicit handling of the cost unit in LLM model creation and update actions, ensuring the unit is always set (defaulting to 'RUN'). Updates the EditModelModal to include a hidden cost_unit input. Refactors backend LLM runtime state refresh logic to improve error handling and logging for cache clearing operations.
2026-01-16 13:58:19 +00:00
Bentlybro
963b8090cc Fix admin LLM API routes and improve model migration
Removes redundant route prefix in backend admin LLM API, updates OpenAPI paths to match, and improves parameterization for batch node updates in model migration and revert logic. Also adds stricter validation for replacement model slug in frontend actions and sets button type in EditModelModal.
2026-01-16 12:51:06 +00:00
Bentlybro
eab93aba2b Add options field to BlockIOStringSubSchema type
Introduces an optional 'options' array to BlockIOStringSubSchema, allowing specification of selectable string values with labels and optional descriptions.
2026-01-16 10:13:33 +00:00
Bentlybro
47a70cdbd0 Merge branch 'dev' into add-llm-manager-ui 2026-01-16 09:39:36 +00:00
Bentlybro
69c9136060 Improve LLM registry consistency and frontend UX
Backend: Refactored LLM registry state updates to use atomic swaps for consistency, made Redis notification publishing async, and improved schema/discriminator mapping access to prevent external mutation. Added stricter slug validation for model creation. Frontend: Enhanced Edit and Delete Model modals to refresh data after actions and show error states, and wrapped the LLM Registry Dashboard in an error boundary for better error handling.
2026-01-12 12:52:40 +00:00
Bentlybro
6ed8bb4f14 Clarify custom pricing override for LLM migrations
Improved documentation and comments for the custom_credit_cost field in backend, frontend, and schema files to clarify its use as a billing override during LLM model migrations. Also removed unused LLM registry types and API methods from frontend code, and renamed useLlmRegistryPage.ts to getLlmRegistryPage.ts for consistency.
2026-01-12 11:40:49 +00:00
Bentlybro
6cf28e58d3 Improve LLM model default selection and admin actions
Backend logic for selecting the default LLM model now prioritizes the recommended model, with improved fallbacks and error handling if no models are enabled. The migration enforces a single recommended model at the database level. Frontend admin actions for LLM models and providers now correctly interpret form values for boolean fields and fix the return type for the delete action.
2026-01-09 15:18:54 +00:00
Bentlybro
632ef24408 Add recommended LLM model feature to admin UI and API
Introduces the ability for admins to mark a model as the recommended default via a new boolean field `isRecommended` on LlmModel. Adds backend endpoints and logic to set, get, and persist the recommended model, including a migration and schema update. Updates the frontend admin UI to allow selecting and displaying the recommended model, and reflects the recommended status in model tables and dropdowns.
2026-01-07 19:43:16 +00:00
Bentlybro
6dc767aafa Improve admin LLM registry UX and error handling
Adds user feedback and error handling to LLM registry modals (add/edit creator, model, provider) in the admin UI, including loading states and error messages. Ensures atomic updates for model costs in the backend using transactions. Improves display of creator website URLs and handles the case where no LLM models are available in analytics config. Updates icon usage and removes unnecessary 'use server' directive.
2026-01-07 14:17:37 +00:00
Bentlybro
23e37fd163 Replace delete button with DeleteCreatorModal
Refactored the creator deletion flow in CreatorsTable to use a new DeleteCreatorModal component, providing a confirmation dialog and improved error handling. The previous DeleteCreatorButton was removed and replaced for better user experience and safety.
2026-01-06 14:22:21 +00:00
Bentlybro
63869fe710 format 2026-01-06 13:40:16 +00:00
Bentlybro
90ae75d475 Delete settings.local.json 2026-01-06 13:07:46 +00:00
Bentlybro
9b6dc3be12 prettier 2026-01-06 13:01:51 +00:00
Bentlybro
9b8b6252c5 Refactor LLM registry admin backend and frontend
Refactored backend imports and test mocks to use new admin LLM routes location. Cleaned up and reordered imports for clarity and consistency. Improved code formatting and readability across backend and frontend files. Renamed useLlmRegistryPage to getLlmRegistryPageData for clarity and updated all usages. No functional changes to business logic.
2026-01-06 12:57:33 +00:00
Bentlybro
0d321323f5 Add GPT-5.2 model and admin LLM endpoints
Introduces a migration to add the GPT-5.2 model and updates the O3 model slug in the database. Refactors backend LLM model registry usage for search and migration logic. Expands the OpenAPI spec with new admin endpoints for managing LLM models, providers, creators, and migrations.
2026-01-06 12:46:20 +00:00
Bentlybro
3ee3ea8f02 Merge branch 'dev' into add-llm-manager-ui 2026-01-06 10:28:43 +00:00
Bentlybro
7a842d35ae Refactor LLM admin to use generated API and types
Replaces usage of the custom BackendApi client and legacy types in admin LLM actions and components with generated OpenAPI endpoints and types. Updates API calls, error handling, and type imports throughout the admin LLM dashboard. Also corrects operationId fields in backend routes and OpenAPI spec for consistency.
2026-01-06 09:43:15 +00:00
Bentlybro
07e8568f57 Refactor LLM admin UI for improved consistency and API support
Refactored admin LLM actions and components to improve code organization, update color schemes to use design tokens, and enhance UI consistency. Updated API types and endpoints to support model creators and migrations, and switched tables to use shared Table components. Added and documented new API endpoints for model migrations, creators, and usage in openapi.json.
2026-01-05 17:10:04 +00:00
Bentlybro
13a0caa5d8 Improve model modal UX and credential provider selection
Add auto-selection of creator based on provider in AddModelModal for better usability. Update EditModelModal to use a select dropdown for credential provider, add helper text, and set credential_type as a hidden default input.
2026-01-05 16:01:36 +00:00
Bentlybro
664523a721 Refactor LLM model cost and update logic, remove 'Enabled' checkbox
Improves backend handling of LLM model cost updates by separating scalar and relation field updates, ensuring costs are deleted and recreated as needed. Optional cost fields are now only included if present, and metadata is handled as a Prisma Json type. On the frontend, removes the 'Enabled' checkbox from the EditModelModal component.
2026-01-05 15:56:45 +00:00
Bentlybro
33b103d09b Improve LLM model migration and add AgentNode index
Refactored model migration and revert logic for atomicity and consistency, including transactional node selection and updates. Enhanced revert API to support optional re-enabling of source models and reporting of nodes not reverted. Added a database index on AgentNode.constantInput->>'model' to optimize migration queries and performance.
2026-01-05 15:22:33 +00:00
Bentlybro
2e3fc99caa Add LLM model creator support to registry and admin UI
Introduces the LlmModelCreator entity to distinguish model creators (e.g., OpenAI, Meta) from providers, with full CRUD API endpoints, database migration, and Prisma schema updates. Backend and frontend are updated to support associating models with creators, including admin UI for managing creators and selecting them when creating or editing models. Existing models are backfilled with known creators via migration.
2026-01-05 10:17:00 +00:00
Bently
52c7b223df Add migration management for LLM models
Introduced a new LlmModelMigration model to track migrations when disabling LLM models, allowing for revert capability. Updated the toggle model API to create migration records with optional reason and custom pricing. Added endpoints for listing and reverting migrations, along with corresponding frontend actions and UI components to manage migrations effectively. Enhanced the admin dashboard to display active migrations, improving overall usability and tracking of model changes.
2025-12-19 00:06:03 +00:00
Bently
24d86fde30 Enhance LLM model toggle functionality with migration support
Updated the toggle LLM model API to include an optional migration feature, allowing workflows to be migrated to a specified replacement model when disabling a model. Refactored related request and response models to accommodate this change. Improved error handling and logging for better debugging. Updated frontend actions and components to support the new migration parameter.
2025-12-18 23:32:41 +00:00
Bentlybro
df7be39724 Refactor add model/provider forms to modal dialogs
Replaces AddModelForm and AddProviderForm components with AddModelModal and AddProviderModal, converting the add model/provider flows to use modal dialogs instead of inline forms. Updates LlmRegistryDashboard to use the new modal components and removes dropdown/form selection logic for a cleaner UI.
2025-12-13 19:39:30 +00:00
Bentlybro
8c7b1af409 Refactor LLM registry to modular structure and improve admin UI
Moved LLM registry backend code into a dedicated llm_registry module with submodules for model types, notifications, schema utilities, and registry logic. Updated all backend imports to use the new structure. On the frontend, redesigned the admin LLM registry page with a dashboard layout, modularized data fetching, and improved forms for adding/editing providers and models. Updated UI components for better usability and maintainability.
2025-12-12 11:32:28 +00:00
Bentlybro
b6e2f05b63 Refactor LlmModel to support dynamic registry slugs
Replaces hardcoded LlmModel enum values with a dynamic approach that accepts any model slug from the registry. Updates block defaults to use a default_factory method that pulls the preferred model from the registry. Refactors model validation, migration, and admin analytics routes to use registry-based model lists, ensuring only enabled models are selectable and recommended. Adds get_default_model_slug to llm_registry for consistent default selection.
2025-12-09 15:49:44 +00:00
Bentlybro
7435739053 Add fallback logic for disabled LLM models
Introduces fallback selection for disabled LLM models in llm_call, preferring enabled models from the same provider. Updates registry utilities to support fallback lookup, model info retrieval, and validation of all known model slugs. Schema utilities now keep all known models in validation enums while showing only enabled models in UI options.
2025-12-08 11:29:31 +00:00
Bentlybro
a97fdba554 Restrict LLM model and provider listings to enabled items
Updated public LLM model and provider listing endpoints to only return enabled models and providers. Refactored database access functions to support filtering by enabled status, and improved transaction safety for model deletion. Adjusted tests and internal documentation to reflect these changes.
2025-12-04 15:56:25 +00:00
Bentlybro
ec705bbbcf format 2025-12-02 14:49:03 +00:00
Bentlybro
7fe6b576ae Add LLM model deletion and migration feature
Introduces backend and frontend support for deleting LLM models with automatic workflow migration to a replacement model. Adds API endpoints, database logic, response models, frontend modal, and actions for safe deletion, including usage count display and error handling. Updates table components to use new modal and refactors table imports.
2025-12-02 14:41:13 +00:00
Bentlybro
dfc42003a1 Refactor LLM registry integration and schema updates
Moved LLM registry schema update logic to a shared utility (llm_schema_utils.py) and refactored block and credentials schema post-processing to use this helper. Extracted executor registry initialization and notification handling into llm_registry_init.py for better separation of concerns. Updated manager.py to use new initialization and subscription functions, improving maintainability and clarity of LLM registry refresh logic.
2025-12-01 17:55:43 +00:00
Bentlybro
6bbeb22943 Refactor LLM model registry to use database
Migrates LLM model metadata and cost configuration from static code to a dynamic database-driven registry. Adds new backend modules for LLM registry and model types, updates block and cost configuration logic to fetch model info and costs from the database, and ensures block schemas and UI options reflect enabled/disabled models. This enables dynamic management of LLM models and costs via the admin UI and database migrations.
2025-12-01 14:37:46 +00:00
157 changed files with 10144 additions and 13741 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -40,48 +40,6 @@ jobs:
git checkout -b "$BRANCH_NAME"
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
# Backend Python/Poetry setup (so Claude can run linting/tests)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
- name: Get CI failure details
id: failure_details
uses: actions/github-script@v8

View File

@@ -77,15 +77,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Enable corepack
run: corepack enable
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend

View File

@@ -93,15 +93,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Enable corepack
run: corepack enable
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend

View File

@@ -62,7 +62,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -7,10 +7,6 @@ on:
- "docs/integrations/**"
- "autogpt_platform/backend/backend/blocks/**"
concurrency:
group: claude-docs-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
claude-review:
# Only run for PRs from members/collaborators
@@ -95,35 +91,5 @@ jobs:
3. Read corresponding documentation files to verify accuracy
4. Provide your feedback as a PR comment
## IMPORTANT: Comment Marker
Start your PR comment with exactly this HTML comment marker on its own line:
<!-- CLAUDE_DOCS_REVIEW -->
This marker is used to identify and replace your comment on subsequent runs.
Be constructive and specific. If everything looks good, say so!
If there are issues, explain what's wrong and suggest how to fix it.
- name: Delete old Claude review comments
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Get all comment IDs with our marker, sorted by creation date (oldest first)
COMMENT_IDS=$(gh api \
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
# Count comments
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
if [ "$COMMENT_COUNT" -gt 1 ]; then
# Delete all but the last (newest) comment
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
if [ -n "$COMMENT_ID" ]; then
echo "Deleting old review comment: $COMMENT_ID"
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
fi
done
else
echo "No old review comments to clean up"
fi

View File

@@ -1,39 +0,0 @@
name: PR Overlap Detection
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- dev
- master
permissions:
contents: read
pull-requests: write
jobs:
check-overlaps:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Need full history for merge testing
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Configure git
run: |
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
- name: Run overlap detection
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Always succeed - this check informs contributors, it shouldn't block merging
continue-on-error: true
run: |
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}

View File

@@ -104,12 +104,6 @@ TWITTER_CLIENT_SECRET=
# Make a new workspace for your OAuth APP -- trust me
# https://linear.app/settings/api/applications/new
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
LINEAR_API_KEY=
# Linear project and team IDs for the feature request tracker.
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
# and in team settings. Used by the chat copilot to file and search feature requests.
LINEAR_FEATURE_REQUEST_PROJECT_ID=
LINEAR_FEATURE_REQUEST_TEAM_ID=
LINEAR_CLIENT_ID=
LINEAR_CLIENT_SECRET=

View File

@@ -66,19 +66,13 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*

View File

@@ -122,6 +122,24 @@ class ConnectionManager:
return len(connections)
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
"""Broadcast a message to all active websocket connections."""
message = WSMessage(
method=method,
data=data,
).model_dump_json()
connections = tuple(self.active_connections)
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()

View File

@@ -176,30 +176,64 @@ async def get_execution_analytics_config(
# Return with provider prefix for clarity
return f"{provider_name}: {model_name}"
# Include all LlmModel values (no more filtering by hardcoded list)
recommended_model = LlmModel.GPT4O_MINI.value
for model in LlmModel:
# Get all models from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
from backend.server.v2.llm import db as llm_db
# Get the recommended model from the database (configurable via admin UI)
recommended_model_slug = await llm_db.get_recommended_model_slug()
# Build the available models list
first_enabled_slug = None
for registry_model in llm_registry.iter_dynamic_models():
# Only include enabled models in the list
if not registry_model.is_enabled:
continue
# Track first enabled model as fallback
if first_enabled_slug is None:
first_enabled_slug = registry_model.slug
model = LlmModel(registry_model.slug)
label = generate_model_label(model)
# Add "(Recommended)" suffix to the recommended model
if model.value == recommended_model:
if registry_model.slug == recommended_model_slug:
label += " (Recommended)"
available_models.append(
ModelInfo(
value=model.value,
value=registry_model.slug,
label=label,
provider=model.provider,
provider=registry_model.metadata.provider,
)
)
# Sort models by provider and name for better UX
available_models.sort(key=lambda x: (x.provider, x.label))
# Handle case where no models are available
if not available_models:
logger.warning(
"No enabled LLM models found in registry. "
"Ensure models are configured and enabled in the LLM Registry."
)
# Provide a placeholder entry so admins see meaningful feedback
available_models.append(
ModelInfo(
value="",
label="No models available - configure in LLM Registry",
provider="none",
)
)
# Use the DB recommended model, or fallback to first enabled model
final_recommended = recommended_model_slug or first_enabled_slug or ""
return ExecutionAnalyticsConfig(
available_models=available_models,
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
default_user_prompt=DEFAULT_USER_PROMPT,
recommended_model=recommended_model,
recommended_model=final_recommended,
)

View File

@@ -0,0 +1,593 @@
import logging
import autogpt_libs.auth
import fastapi
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
tags=["llm", "admin"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
try:
# Refresh registry from database
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.blocks._base import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.api.features.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
logger.warning("Failed to clear /blocks cache: %s", e)
# Clear the v2 builder caches
try:
from backend.api.features.builder import db as builder_db
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
builder_db._build_cached_search_results.cache_clear()
logger.info("Cleared v2 builder search results cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry import publish_registry_refresh_notification
await publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
except Exception as exc:
logger.exception(
"LLM runtime state refresh failed; caches may be stale: %s", exc
)
@router.get(
"/providers",
summary="List LLM providers",
response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
providers = await llm_db.list_providers(include_models=include_models)
return llm_model.LlmProvidersResponse(providers=providers)
@router.post(
"/providers",
summary="Create LLM provider",
response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
provider = await llm_db.upsert_provider(request=request)
await _refresh_runtime_state()
return provider
@router.patch(
"/providers/{provider_id}",
summary="Update LLM provider",
response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
provider_id: str,
request: llm_model.UpsertLlmProviderRequest,
):
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
await _refresh_runtime_state()
return provider
@router.delete(
"/providers/{provider_id}",
summary="Delete LLM provider",
response_model=dict,
)
async def delete_llm_provider(provider_id: str):
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Delete all models from the provider first before deleting the provider.
"""
try:
await llm_db.delete_provider(provider_id)
await _refresh_runtime_state()
logger.info("Deleted LLM provider '%s'", provider_id)
return {"success": True, "message": "Provider deleted successfully"}
except ValueError as e:
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=500, detail=str(e))
@router.get(
"/models",
summary="List LLM models",
response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(
provider_id: str | None = fastapi.Query(default=None),
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
return await llm_db.list_models(
provider_id=provider_id, page=page, page_size=page_size
)
@router.post(
"/models",
summary="Create LLM model",
response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
model = await llm_db.create_model(request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}",
summary="Update LLM model",
response_model=llm_model.LlmModel,
)
async def update_llm_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
):
model = await llm_db.update_model(model_id=model_id, request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}/toggle",
summary="Toggle LLM model availability",
response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
model_id: str,
request: llm_model.ToggleLlmModelRequest,
):
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
If disabling a model and `migrate_to_slug` is provided, all workflows using
this model will be migrated to the specified replacement model before disabling.
A migration record is created which can be reverted later using the revert endpoint.
Optional fields:
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
- `custom_credit_cost`: Custom pricing override for billing during migration
"""
try:
result = await llm_db.toggle_model(
model_id=model_id,
is_enabled=request.is_enabled,
migrate_to_slug=request.migrate_to_slug,
migration_reason=request.migration_reason,
custom_credit_cost=request.custom_credit_cost,
)
await _refresh_runtime_state()
if result.nodes_migrated > 0:
logger.info(
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
result.model.slug,
"enabled" if request.is_enabled else "disabled",
result.nodes_migrated,
result.migrated_to_slug,
result.migration_id,
)
return result
except ValueError as exc:
logger.warning("Model toggle validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to toggle model availability",
) from exc
@router.get(
"/models/{model_id}/usage",
summary="Get model usage count",
response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
"""Get the number of workflow nodes using this model."""
try:
return await llm_db.get_model_usage(model_id=model_id)
except ValueError as exc:
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to get model usage %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get model usage",
) from exc
@router.delete(
"/models/{model_id}",
summary="Delete LLM model and migrate workflows",
response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
model_id: str,
replacement_model_slug: str | None = fastapi.Query(
default=None,
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
),
):
"""
Delete a model and optionally migrate workflows using it to a replacement model.
If no workflows are using this model, it can be deleted without providing a
replacement. If workflows exist, replacement_model_slug is required.
This endpoint:
1. Counts how many workflow nodes use the model being deleted
2. If nodes exist, validates the replacement model and migrates them
3. Deletes the model record
4. Refreshes all caches and notifies executors
Example: DELETE /api/llm/admin/models/{id}?replacement_model_slug=gpt-4o
Example (no usage): DELETE /api/llm/admin/models/{id}
"""
try:
result = await llm_db.delete_model(
model_id=model_id, replacement_model_slug=replacement_model_slug
)
await _refresh_runtime_state()
logger.info(
"Deleted model '%s' and migrated %d nodes to '%s'",
result.deleted_model_slug,
result.nodes_migrated,
result.replacement_model_slug,
)
return result
except ValueError as exc:
# Validation errors (model not found, replacement invalid, etc.)
logger.warning("Model deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete model and migrate workflows",
) from exc
# ============================================================================
# Migration Management Endpoints
# ============================================================================
@router.get(
"/migrations",
summary="List model migrations",
response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
include_reverted: bool = fastapi.Query(
default=False, description="Include reverted migrations in the list"
),
):
"""
List all model migrations.
Migrations are created when disabling a model with the migrate_to_slug option.
They can be reverted to restore the original model configuration.
"""
try:
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
return llm_model.LlmMigrationsResponse(migrations=migrations)
except Exception as exc:
logger.exception("Failed to list migrations: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list migrations",
) from exc
@router.get(
"/migrations/{migration_id}",
summary="Get migration details",
response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
"""Get details of a specific migration."""
try:
migration = await llm_db.get_migration(migration_id)
if not migration:
raise fastapi.HTTPException(
status_code=404, detail=f"Migration '{migration_id}' not found"
)
return migration
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get migration",
) from exc
@router.post(
"/migrations/{migration_id}/revert",
summary="Revert a model migration",
response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
migration_id: str,
request: llm_model.RevertMigrationRequest | None = None,
):
"""
Revert a model migration, restoring affected workflows to their original model.
This only reverts the specific nodes that were part of the migration.
The source model must exist for the revert to succeed.
Options:
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
Response includes:
- `nodes_reverted`: Number of nodes successfully reverted
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
- `source_model_re_enabled`: Whether the source model was re-enabled
Requirements:
- Migration must not already be reverted
- Source model must exist
"""
try:
re_enable = request.re_enable_source_model if request else True
result = await llm_db.revert_migration(
migration_id,
re_enable_source_model=re_enable,
)
await _refresh_runtime_state()
logger.info(
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
"(%d already changed, source re-enabled=%s)",
migration_id,
result.nodes_reverted,
result.target_model_slug,
result.source_model_slug,
result.nodes_already_changed,
result.source_model_re_enabled,
)
return result
except ValueError as exc:
logger.warning("Migration revert validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to revert migration",
) from exc
# ============================================================================
# Creator Management Endpoints
# ============================================================================
@router.get(
"/creators",
summary="List model creators",
response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
"""
List all model creators.
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
This is distinct from providers who host/serve the models (e.g., OpenRouter).
"""
try:
creators = await llm_db.list_creators()
return llm_model.LlmCreatorsResponse(creators=creators)
except Exception as exc:
logger.exception("Failed to list creators: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list creators",
) from exc
@router.get(
"/creators/{creator_id}",
summary="Get creator details",
response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
"""Get details of a specific model creator."""
try:
creator = await llm_db.get_creator(creator_id)
if not creator:
raise fastapi.HTTPException(
status_code=404, detail=f"Creator '{creator_id}' not found"
)
return creator
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get creator",
) from exc
@router.post(
"/creators",
summary="Create model creator",
response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
"""
Create a new model creator.
A creator represents an organization that creates/trains AI models,
such as OpenAI, Anthropic, Meta, or Google.
"""
try:
creator = await llm_db.upsert_creator(request=request)
await _refresh_runtime_state()
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
return creator
except Exception as exc:
logger.exception("Failed to create creator: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to create creator",
) from exc
@router.patch(
"/creators/{creator_id}",
summary="Update model creator",
response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
creator_id: str,
request: llm_model.UpsertLlmCreatorRequest,
):
"""Update an existing model creator."""
try:
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
await _refresh_runtime_state()
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
return creator
except Exception as exc:
logger.exception("Failed to update creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to update creator",
) from exc
@router.delete(
"/creators/{creator_id}",
summary="Delete model creator",
response_model=dict,
)
async def delete_llm_creator(creator_id: str):
"""
Delete a model creator.
This will remove the creator association from all models that reference it
(sets creatorId to NULL), but will not delete the models themselves.
"""
try:
await llm_db.delete_creator(creator_id)
await _refresh_runtime_state()
logger.info("Deleted model creator '%s'", creator_id)
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
except ValueError as exc:
logger.warning("Creator deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete creator",
) from exc
# ============================================================================
# Recommended Model Endpoints
# ============================================================================
@router.get(
"/recommended-model",
summary="Get recommended model",
response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
"""
Get the currently recommended LLM model.
The recommended model is shown to users as the default/suggested option
in model selection dropdowns.
"""
try:
model = await llm_db.get_recommended_model()
return llm_model.RecommendedModelResponse(
model=model,
slug=model.slug if model else None,
)
except Exception as exc:
logger.exception("Failed to get recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get recommended model",
) from exc
@router.post(
"/recommended-model",
summary="Set recommended model",
response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
"""
Set a model as the recommended model.
This clears the recommended flag from any other model and sets it on
the specified model. The model must be enabled to be set as recommended.
The recommended model is displayed to users as the default/suggested
option in model selection dropdowns throughout the platform.
"""
try:
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
await _refresh_runtime_state()
logger.info(
"Set recommended model to '%s' (previous: %s)",
model.slug,
previous_slug or "none",
)
return llm_model.SetRecommendedModelResponse(
model=model,
previous_recommended_slug=previous_slug,
message=f"Model '{model.display_name}' is now the recommended model",
)
except ValueError as exc:
logger.warning("Set recommended model validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to set recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to set recommended model",
) from exc

View File

@@ -0,0 +1,491 @@
import json
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
import backend.api.features.admin.llm_routes as llm_routes
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
app = fastapi.FastAPI()
app.include_router(llm_routes.router, prefix="/admin/llm")
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_list_llm_providers_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM providers"""
# Mock the database function
mock_providers = [
{
"id": "provider-1",
"name": "openai",
"display_name": "OpenAI",
"description": "OpenAI LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
{
"id": "provider-2",
"name": "anthropic",
"display_name": "Anthropic",
"description": "Anthropic LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
]
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_providers",
new=AsyncMock(return_value=mock_providers),
)
response = client.get("/admin/llm/providers")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["providers"]) == 2
assert response_data["providers"][0]["name"] == "openai"
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_providers_success.json",
)
def test_list_llm_models_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM models with pagination"""
# Mock the database function - now returns LlmModelsResponse
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=True,
capabilities={},
metadata={},
costs=[
llm_model.LlmModelCost(
id="cost-1",
credit_cost=10,
credential_provider="openai",
metadata={},
)
],
)
mock_response = llm_model.LlmModelsResponse(
models=[mock_model],
pagination=Pagination(
total_items=1,
total_pages=1,
current_page=1,
page_size=50,
),
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_models",
new=AsyncMock(return_value=mock_response),
)
response = client.get("/admin/llm/models")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["models"]) == 1
assert response_data["models"][0]["slug"] == "gpt-4o"
assert response_data["pagination"]["total_items"] == 1
assert response_data["pagination"]["page_size"] == 50
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_models_success.json",
)
def test_create_llm_provider_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM provider"""
mock_provider = {
"id": "new-provider-id",
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
new=AsyncMock(return_value=mock_provider),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
response = client.post("/admin/llm/providers", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["name"] == "groq"
assert response_data["display_name"] == "Groq"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_provider_success.json",
)
def test_create_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM model"""
mock_model = {
"id": "new-model-id",
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-id",
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.create_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
response = client.post("/admin/llm/models", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["slug"] == "gpt-4.1-mini"
assert response_data["is_enabled"] is True
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_model_success.json",
)
def test_update_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful update of LLM model"""
mock_model = {
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o Updated",
"description": "Updated description",
"provider_id": "provider-1",
"context_window": 256000,
"max_output_tokens": 32768,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-1",
"credit_cost": 15,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.update_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"display_name": "GPT-4o Updated",
"description": "Updated description",
"context_window": 256000,
"max_output_tokens": 32768,
}
response = client.patch("/admin/llm/models/model-1", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["display_name"] == "GPT-4o Updated"
assert response_data["context_window"] == 256000
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"update_llm_model_success.json",
)
def test_toggle_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful toggling of LLM model enabled status"""
# Create a proper mock model object
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=False,
capabilities={},
metadata={},
costs=[],
)
# Create a proper ToggleLlmModelResponse
mock_response = llm_model.ToggleLlmModelResponse(
model=mock_model,
nodes_migrated=0,
migrated_to_slug=None,
migration_id=None,
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {"is_enabled": False}
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["model"]["is_enabled"] is False
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"toggle_llm_model_success.json",
)
def test_delete_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful deletion of LLM model with migration"""
# Create a proper DeleteLlmModelResponse
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="gpt-3.5-turbo",
deleted_model_display_name="GPT-3.5 Turbo",
replacement_model_slug="gpt-4o-mini",
nodes_migrated=42,
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete(
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
)
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
assert response_data["nodes_migrated"] == 42
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"delete_llm_model_success.json",
)
def test_delete_llm_model_validation_error(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails with proper error when validation fails"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
)
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
assert response.status_code == 400
assert "Replacement model 'invalid' not found" in response.json()["detail"]
def test_delete_llm_model_no_replacement_with_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails when nodes exist but no replacement is provided"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(
side_effect=ValueError(
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
"Please provide a replacement_model_slug to migrate them."
)
),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 400
assert "workflow node(s) are using it" in response.json()["detail"]
def test_delete_llm_model_no_replacement_no_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="unused-model",
deleted_model_display_name="Unused Model",
replacement_model_slug=None,
nodes_migrated=0,
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "unused-model"
assert response_data["nodes_migrated"] == 0
assert response_data["replacement_model_slug"] is None
mock_refresh.assert_called_once()

View File

@@ -20,6 +20,7 @@ from backend.blocks._base import (
)
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.data.llm_registry import get_all_model_slugs_for_validation
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -36,7 +37,14 @@ from .model import (
)
logger = logging.getLogger(__name__)
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
def _get_llm_models() -> list[str]:
"""Get LLM model names for search matching from the registry."""
return [
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -501,8 +509,10 @@ async def _get_static_counts():
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
# Normalize query same as model slugs (lowercase, hyphens to spaces)
normalized_model_query = query.lower().replace("-", " ")
# Check if query matches any value in llm_models from registry
if any(normalized_model_query in name for name in _get_llm_models()):
return True
return False

View File

@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -92,31 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of sub-agent Tasks the SDK can spawn per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
@@ -162,17 +138,6 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -334,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
@@ -377,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None
messages = prisma_session.Messages
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
@@ -430,9 +433,10 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
@@ -472,7 +476,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database")
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -489,6 +493,7 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -553,40 +558,6 @@ async def upsert_chat_session(
return session
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -693,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
# Invalidate cache so next fetch gets updated title
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:

View File

@@ -1,6 +1,5 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
@@ -12,23 +11,13 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from backend.util.feature_flag import Flag, is_feature_enabled
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -52,7 +41,6 @@ from .tools.models import (
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .tracking import track_user_message
config = ChatConfig()
@@ -212,43 +200,6 @@ async def create_session(
)
@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.
Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.
Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.
Returns:
204 No Content on success.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return Response(status_code=204)
@router.get(
"/sessions/{session_id}",
)
@@ -281,10 +232,6 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
@@ -354,9 +301,10 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -365,25 +313,6 @@ async def stream_chat_post(
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
@@ -399,7 +328,7 @@ async def stream_chat_post(
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -420,47 +349,15 @@ async def stream_chat_post(
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on LaunchDarkly flag (falls back to config default)
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
async for chunk in chat_service.stream_chat_completion(
session_id,
None, # Message already in session
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
@@ -481,7 +378,7 @@ async def stream_chat_post(
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
@@ -508,17 +405,6 @@ async def stream_chat_post(
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
@@ -621,14 +507,8 @@ async def stream_chat_post(
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
@@ -872,6 +752,8 @@ async def stream_task(
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:

View File

@@ -1,14 +0,0 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]

View File

@@ -1,203 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"q": "x"},
)
],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent" # prefix stripped
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].toolName == "find_agent" # prefix stripped
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call (opens step) — SDK sends prefixed name
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
model="test",
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
],
model="test",
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(
id="t1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"query": "email"},
)
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]

View File

@@ -1,335 +0,0 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
logger = logging.getLogger(__name__)
# Tools that are blocked entirely (CLI/system access).
# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked
# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
BLOCKED_TOOLS = {
"Bash",
"bash",
"shell",
"exec",
"terminal",
"command",
}
# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
def _validate_workspace_path(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
f"directory.{workspace_hint} "
"This is enforced by the platform and cannot be bypassed."
)
def _validate_tool_access(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
"Use the CoPilot-specific MCP tools instead."
)
# Workspace-scoped tools: allowed only within the SDK workspace directory
if tool_name in WORKSPACE_SCOPED_TOOLS:
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
# Check for dangerous patterns in tool input
# Use json.dumps for predictable format (str() produces Python repr)
input_str = json.dumps(tool_input) if tool_input else ""
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny(
"[SECURITY] Input contains a blocked pattern. "
"This is enforced by the platform and cannot be bypassed."
)
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Rate-limit Task (sub-agent) spawns per session
if tool_name == "Task":
task_spawn_count += 1
if task_spawn_count > max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
),
)
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
# Only block non-CoPilot tools; our MCP-registered tools
# (including Read for oversized results) are already sandboxed.
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---
async def stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Capture transcript path when SDK finishes processing.
The Stop hook fires while the CLI process is still alive, giving us
a reliable window to read the JSONL transcript before SIGTERM.
"""
_ = context, tool_use_id
transcript_path = cast(str, input_data.get("transcript_path", ""))
sdk_session_id = cast(str, input_data.get("session_id", ""))
if transcript_path and on_stop:
logger.info(
f"[SDK] Stop hook: transcript_path={transcript_path}, "
f"sdk_session_id={sdk_session_id[:12]}..."
)
on_stop(transcript_path, sdk_session_id)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
if on_stop is not None:
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
return hooks
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
logger.warning("claude-agent-sdk not available, security hooks disabled")
return {}

View File

@@ -1,165 +0,0 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
def test_bash_builtin_always_blocked():
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}

View File

@@ -1,751 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
)
from .transcript import (
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
@property
def available(self) -> bool:
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SDK_TOOL_SUPPLEMENT = """
## Tool notes
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
same working directory. Files created by one are readable by the other.
These files are **ephemeral** — they exist only for the current session.
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
for files that should persist across sessions (stored in cloud storage).
- Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
"""
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
Uses ``config.claude_agent_model`` if set, otherwise derives from
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
"""
if config.claude_agent_model:
return config.claude_agent_model
model = config.model
if "/" in model:
return model.split("/", 1)[1]
return model
def _build_sdk_env() -> dict[str, str]:
"""Build env vars for the SDK CLI process.
Routes API calls through OpenRouter (or a custom base_url) using
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
This gives per-call token and cost tracking on the OpenRouter dashboard.
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
token are both present — otherwise returns an empty dict so the SDK
falls back to its default credentials.
"""
env: dict[str, str] = {}
if config.api_key and config.base_url:
# Strip /v1 suffix — SDK expects the base URL without a version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
if not base or not base.startswith("http"):
# Invalid base_url — don't override SDK defaults
return env
env["ANTHROPIC_BASE_URL"] = base
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
env["ANTHROPIC_API_KEY"] = ""
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
(single source of truth for path sanitization) and adds a defence-in-depth
assertion.
"""
cwd = make_session_path(session_id)
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
cwd = os.path.normpath(cwd)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check 3: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Initialise sdk_cwd before the try so the finally can reference it
# even if _make_sdk_cwd raises (in that case it stays as "").
sdk_cwd = ""
use_resume = False
try:
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env()
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY "
"(or CHAT_API_KEY) for OpenRouter routing, "
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
mcp_server = create_copilot_mcp_server()
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
captured_transcript = CapturedTranscript()
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
security_hooks = create_security_hooks(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
)
# --- Resume strategy: download transcript from bucket ---
resume_file: str | None = None
use_resume = False
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
transcript_content = await download_transcript(user_id, session_id)
if transcript_content and validate_transcript(transcript_content):
resume_file = write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
logger.info(
f"[SDK] Using --resume with transcript "
f"({len(transcript_content)} bytes)"
)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": COPILOT_TOOL_NAMES,
"disallowed_tools": ["Bash"],
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
}
if sdk_env:
sdk_options_kwargs["model"] = sdk_model
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query: with --resume the CLI already has full
# context, so we only send the new message. Without
# resume, compress history into a context prefix.
query_message = current_message
if not use_resume and len(session.messages) > 1:
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({len(session.messages)} messages) — "
f"no transcript available for --resume"
)
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
)
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Capture transcript while CLI is still alive ---
# Must happen INSIDE async with: close() sends SIGTERM
# which kills the CLI before it can flush the JSONL.
if (
config.claude_agent_use_resume
and user_id
and captured_transcript.available
):
# Give CLI time to flush JSONL writes before we read
await asyncio.sleep(0.5)
raw_transcript = read_transcript_file(captured_transcript.path)
if raw_transcript:
task = asyncio.create_task(
_upload_transcript_bg(user_id, session_id, raw_transcript)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
else:
logger.debug("[SDK] Stop hook fired but transcript not usable")
except ImportError:
raise RuntimeError(
"claude-agent-sdk is not installed. "
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
"to use the OpenAI-compatible fallback."
)
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
async def _upload_transcript_bg(
user_id: str, session_id: str, raw_content: str
) -> None:
"""Background task to strip progress entries and upload transcript."""
try:
await upload_transcript(user_id, session_id, raw_content)
except Exception as e:
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -1,322 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import itertools
import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
return pending.pop(tool_name, None)
async def _execute_tool_sync(
base_tool: BaseTool,
user_id: str | None,
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
text = (
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending[base_tool.name] = text
return {
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
],
"isError": True,
}
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session = get_execution_context()
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
try:
with open(real_path) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
"Read a file from the local filesystem. "
"Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read",
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-indexed). Default: 0",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000",
},
},
"required": ["file_path"],
}
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]

View File

@@ -1,356 +0,0 @@
"""JSONL transcript management for stateless multi-turn resume.
The Claude Code CLI persists conversations as JSONL files (one JSON object per
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
(progress entries, metadata), and upload the result to bucket storage. On the
next turn we download the transcript, write it to a temp file, and pass
``--resume`` so the CLI can reconstruct the full conversation.
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
import json
import logging
import os
import re
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
# Entry types that can be safely removed from the transcript without breaking
# the parentUuid conversation tree that ``--resume`` relies on.
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
# - file-history-snapshot: undo tracking metadata
# - queue-operation: internal queue bookkeeping
# - summary: session summaries
# - pr-link: PR link metadata
STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
# ---------------------------------------------------------------------------
# Progress stripping
# ---------------------------------------------------------------------------
def strip_progress_entries(content: str) -> str:
"""Remove progress/metadata entries from a JSONL transcript.
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
any remaining child entries so the ``parentUuid`` chain stays intact.
Typically reduces transcript size by ~30%.
"""
lines = content.strip().split("\n")
entries: list[dict] = []
for line in lines:
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
# Keep unparseable lines as-is (safety)
entries.append({"_raw": line})
stripped_uuids: set[str] = set()
uuid_to_parent: dict[str, str] = {}
kept: list[dict] = []
for entry in entries:
if "_raw" in entry:
kept.append(entry)
continue
uid = entry.get("uuid", "")
parent = entry.get("parentUuid", "")
entry_type = entry.get("type", "")
if uid:
uuid_to_parent[uid] = parent
if entry_type in STRIPPABLE_TYPES:
if uid:
stripped_uuids.add(uid)
else:
kept.append(entry)
# Reparent: walk up chain through stripped entries to find surviving ancestor
for entry in kept:
if "_raw" in entry:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
result_lines: list[str] = []
for entry in kept:
if "_raw" in entry:
result_lines.append(entry["_raw"])
else:
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"
# ---------------------------------------------------------------------------
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
# ---------------------------------------------------------------------------
def read_transcript_file(transcript_path: str) -> str | None:
"""Read a JSONL transcript file from disk.
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
or only contains metadata (≤2 lines with no conversation messages).
"""
if not transcript_path or not os.path.isfile(transcript_path):
logger.debug(f"[Transcript] File not found: {transcript_path}")
return None
try:
with open(transcript_path) as f:
content = f.read()
if not content.strip():
logger.debug(f"[Transcript] Empty file: {transcript_path}")
return None
lines = content.strip().split("\n")
if len(lines) < 3:
# Raw files with ≤2 lines are metadata-only
# (queue-operation + file-history-snapshot, no conversation).
logger.debug(
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
)
return None
# Quick structural validation — parse first and last lines.
json.loads(lines[0])
json.loads(lines[-1])
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
)
return content
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
return None
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
"""Sanitize an ID for safe use in file paths.
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
everything else and truncate to *max_len* so the result cannot introduce
path separators or other special characters.
"""
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
return cleaned or "unknown"
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,
cwd: str,
) -> str | None:
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
The file lives in the session working directory so it is cleaned up
automatically when the session ends.
Returns the absolute path to the file, or ``None`` on failure.
"""
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
return None
try:
os.makedirs(real_cwd, exist_ok=True)
safe_id = _sanitize_id(session_id, max_len=8)
jsonl_path = os.path.realpath(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
return None
def validate_transcript(content: str | None) -> bool:
"""Check that a transcript has actual conversation messages.
A valid transcript for resume needs at least one user message and one
assistant message (not just queue-operation / file-history-snapshot
metadata).
"""
if not content or not content.strip():
return False
lines = content.strip().split("\n")
if len(lines) < 2:
return False
has_user = False
has_assistant = False
for line in lines:
try:
entry = json.loads(line)
msg_type = entry.get("type")
if msg_type == "user":
has_user = True
elif msg_type == "assistant":
has_assistant = True
except json.JSONDecodeError:
return False
return has_user and has_assistant
# ---------------------------------------------------------------------------
# Bucket storage (GCS / local via WorkspaceStorageBackend)
# ---------------------------------------------------------------------------
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript.
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
IDs are sanitized to hex+hyphen to prevent path traversal.
"""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.jsonl",
)
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects.
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
``local://workspace_id/file_id/filename``. Since we use deterministic
arguments we can reconstruct the same path for download/delete without
having stored the return value.
"""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = _storage_path_parts(user_id, session_id)
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
return f"gcs://{backend.bucket_name}/{blob}"
else:
# LocalWorkspaceStorage returns local://{relative_path}
return f"local://{wid}/{fid}/{fname}"
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
"""Strip progress entries and upload transcript to bucket storage.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
"""
from backend.util.workspace_storage import get_workspace_storage
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
logger.warning(
f"[Transcript] Skipping upload — stripped content is not a valid "
f"transcript for session {session_id}"
)
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing transcript "
f"({len(existing)}B) >= new ({new_size}B) for session "
f"{session_id}"
)
return
except (FileNotFoundError, Exception):
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
)
logger.info(
f"[Transcript] Uploaded {new_size} bytes "
f"(stripped from {len(content)}) for session {session_id}"
)
async def download_transcript(user_id: str, session_id: str) -> str | None:
"""Download transcript from bucket storage.
Returns the JSONL content string, or ``None`` if not found.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
data = await storage.retrieve(path)
content = data.decode("utf-8")
logger.info(
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
)
return content
except FileNotFoundError:
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
return None
except Exception as e:
logger.warning(f"[Transcript] Failed to download transcript: {e}")
return None
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
except Exception as e:
logger.warning(f"[Transcript] Failed to delete transcript: {e}")

View File

@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt(
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available.
Args:
user_id: The user ID for fetching business understanding.
has_conversation_history: Whether there's existing conversation history.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
user_id: The user ID for fetching business understanding
If "default" and this is the user's first session, will use "onboarding" instead.
Returns:
Tuple of (compiled prompt string, business understanding object)
@@ -270,8 +266,6 @@ async def _build_system_prompt(
if understanding:
context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -380,6 +374,7 @@ async def stream_chat_completion(
Raises:
NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
@@ -464,9 +459,8 @@ async def stream_chat_completion(
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
user_messages = [m for m in session.messages if m.role == "user"]
first_user_msg = message or (user_messages[0].content if user_messages else None)
if is_user_message and first_user_msg and not session.title:
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
@@ -474,7 +468,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session
captured_session_id = session_id
captured_message = first_user_msg
captured_message = message
captured_user_id = user_id
async def _update_title():
@@ -1243,7 +1237,7 @@ async def _stream_chat_chunks(
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
@@ -1251,7 +1245,6 @@ async def _stream_chat_chunks(
return
except Exception as e:
last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
# Calculate delay with exponential backoff
@@ -1267,27 +1260,12 @@ async def _stream_chat_chunks(
continue # Retry the stream
else:
# Non-retryable error or max retries exceeded
_log_api_error(
error=e,
context="stream (not retrying)",
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=retry_count,
logger.error(
f"Error in stream (not retrying): {e!s}",
exc_info=True,
)
error_code = None
error_text = str(e)
error_details = _extract_api_error_details(e)
if error_details.get("response_body"):
body = error_details["response_body"]
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict) and err.get("message"):
error_text = err["message"]
elif body.get("message"):
error_text = body["message"]
if _is_region_blocked_error(e):
error_code = "MODEL_NOT_AVAILABLE_REGION"
error_text = (
@@ -1304,13 +1282,9 @@ async def _stream_chat_chunks(
# If we exit the retry loop without returning, it means we exhausted retries
if last_error:
_log_api_error(
error=last_error,
context=f"stream (max retries {MAX_RETRIES} exceeded)",
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=MAX_RETRIES,
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
exc_info=True,
)
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
yield StreamFinish()
@@ -1883,7 +1857,6 @@ async def _generate_llm_continuation(
break # Success, exit retry loop
except Exception as e:
last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
delay = min(
@@ -1897,25 +1870,17 @@ async def _generate_llm_continuation(
await asyncio.sleep(delay)
continue
else:
# Non-retryable error - log details and exit gracefully
_log_api_error(
error=e,
context="LLM continuation (not retrying)",
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=retry_count,
# Non-retryable error - log and exit gracefully
logger.error(
f"Non-retryable error in LLM continuation: {e!s}",
exc_info=True,
)
return
if last_error:
_log_api_error(
error=last_error,
context=f"LLM continuation (max retries {MAX_RETRIES} exceeded)",
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=MAX_RETRIES,
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
f"Last error: {last_error!s}"
)
return
@@ -1955,91 +1920,6 @@ async def _generate_llm_continuation(
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
def _log_api_error(
error: Exception,
context: str,
session_id: str | None = None,
message_count: int | None = None,
model: str | None = None,
retry_count: int = 0,
) -> None:
"""Log detailed API error information for debugging."""
details = _extract_api_error_details(error)
details["context"] = context
details["session_id"] = session_id
details["message_count"] = message_count
details["model"] = model
details["retry_count"] = retry_count
if isinstance(error, RateLimitError):
logger.warning(f"Rate limit error in {context}: {details}", exc_info=error)
elif isinstance(error, APIConnectionError):
logger.warning(f"API connection error in {context}: {details}", exc_info=error)
elif isinstance(error, APIStatusError) and error.status_code >= 500:
logger.error(f"API server error (5xx) in {context}: {details}", exc_info=error)
else:
logger.error(f"API error in {context}: {details}", exc_info=error)
def _extract_api_error_details(error: Exception) -> dict[str, Any]:
"""Extract detailed information from OpenAI/OpenRouter API errors."""
error_msg = str(error)
details: dict[str, Any] = {
"error_type": type(error).__name__,
"error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
}
if hasattr(error, "code"):
details["code"] = getattr(error, "code", None)
if hasattr(error, "param"):
details["param"] = getattr(error, "param", None)
if isinstance(error, APIStatusError):
details["status_code"] = error.status_code
details["request_id"] = getattr(error, "request_id", None)
if hasattr(error, "body") and error.body:
details["response_body"] = _sanitize_error_body(error.body)
if hasattr(error, "response") and error.response:
headers = error.response.headers
details["openrouter_provider"] = headers.get("x-openrouter-provider")
details["openrouter_model"] = headers.get("x-openrouter-model")
details["retry_after"] = headers.get("retry-after")
details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")
return details
def _sanitize_error_body(
body: Any, max_length: int = 2000
) -> dict[str, Any] | str | None:
"""Extract only safe fields from error response body to avoid logging sensitive data."""
if not isinstance(body, dict):
# Non-dict bodies (e.g., HTML error pages) - return truncated string
if body is not None:
body_str = str(body)
if len(body_str) > max_length:
return body_str[:max_length] + "...[truncated]"
return body_str
return None
safe_fields = ("message", "type", "code", "param", "error")
sanitized: dict[str, Any] = {}
for field in safe_fields:
if field in body:
value = body[field]
if field == "error" and isinstance(value, dict):
sanitized[field] = _sanitize_error_body(value, max_length)
elif isinstance(value, str) and len(value) > max_length:
sanitized[field] = value[:max_length] + "...[truncated]"
else:
sanitized[field] = value
return sanitized if sanitized else None
async def _generate_llm_continuation_with_streaming(
session_id: str,
user_id: str | None,

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from os import getenv
@@ -12,8 +11,6 @@ from .response_model import (
StreamTextDelta,
StreamToolOutputAvailable,
)
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
logger = logging.getLogger(__name__)
@@ -83,96 +80,3 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"""Test that the SDK --resume path captures and uses transcripts across turns.
Turn 1: Send a message containing a unique keyword.
Turn 2: Ask the model to recall that keyword — proving the transcript was
persisted and restored via --resume.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
from .config import ChatConfig
cfg = ChatConfig()
if not cfg.claude_agent_use_resume:
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"
turn1_msg = (
f"Please remember this special keyword: {keyword}. "
"Just confirm you've noted it, keep your response brief."
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
turn1_msg,
user_id=test_user_id,
):
if isinstance(chunk, StreamTextDelta):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
# Wait for background upload task to complete (retry up to 5s)
transcript = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
break
assert transcript, (
"Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
assert session, "Session not found after turn 1"
# --- Turn 2: ask model to recall the keyword ---
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
turn2_msg,
user_id=test_user_id,
session=session,
):
if isinstance(chunk, StreamTextDelta):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (
f"Model did not recall keyword '{keyword}' in turn 2. "
f"Response: {turn2_text[:200]}"
)
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

View File

@@ -814,28 +814,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"

View File

@@ -9,12 +9,9 @@ from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
@@ -22,7 +19,6 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
@@ -47,17 +43,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -1,131 +0,0 @@
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
"""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
BashExecResponse,
ErrorResponse,
ToolResponseBase,
)
from backend.api.features.chat.tools.sandbox import (
get_workspace_dir,
has_full_sandbox,
run_sandboxed,
)
logger = logging.getLogger(__name__)
class BashExecTool(BaseTool):
"""Execute Bash commands in a bubblewrap sandbox."""
@property
def name(self) -> str:
return "bash_exec"
@property
def description(self) -> str:
if not has_full_sandbox():
return (
"Bash execution is DISABLED — bubblewrap sandbox is not "
"available on this platform. Do not call this tool."
)
return (
"Execute a Bash command or script in a bubblewrap sandbox. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The sandbox shares the same working directory as the SDK Read/Write "
"tools — files created by either are accessible to both. "
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
"visible read-only, the per-session workspace is the only writable "
"path, environment variables are wiped (no secrets), all network "
"access is blocked at the kernel level, and resource limits are "
"enforced (max 64 processes, 512MB memory, 50MB file size). "
"Application code, configs, and other directories are NOT accessible. "
"To fetch web content, use the web_fetch tool instead. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing with Unix tools "
"(grep, awk, sed, jq, etc.), and running shell scripts."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Bash command or script to execute.",
},
"timeout": {
"type": "integer",
"description": (
"Max execution time in seconds (default 30, max 120)."
),
"default": 30,
},
},
"required": ["command"],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
session_id = session.session_id if session else None
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
command: str = (kwargs.get("command") or "").strip()
timeout: int = kwargs.get("timeout", 30)
if not command:
return ErrorResponse(
message="No command provided.",
error="empty_command",
session_id=session_id,
)
workspace = get_workspace_dir(session_id or "default")
stdout, stderr, exit_code, timed_out = await run_sandboxed(
command=["bash", "-c", command],
cwd=workspace,
timeout=timeout,
)
return BashExecResponse(
message=(
"Execution timed out"
if timed_out
else f"Command executed (exit {exit_code})"
),
stdout=stdout,
stderr=stderr,
exit_code=exit_code,
timed_out=timed_out,
session_id=session_id,
)

View File

@@ -1,127 +0,0 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ResponseType,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.api.features.chat import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -1,448 +0,0 @@
"""Feature request tools - search and create feature requests via Linear."""
import logging
from typing import Any
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
FeatureRequestCreatedResponse,
FeatureRequestInfo,
FeatureRequestSearchResponse,
NoResultsResponse,
ToolResponseBase,
)
from backend.blocks.linear._api import LinearClient
from backend.data.model import APIKeyCredentials
from backend.data.user import get_user_email_by_id
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
MAX_SEARCH_RESULTS = 10
# GraphQL queries/mutations
SEARCH_ISSUES_QUERY = """
query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) {
searchIssues(term: $term, filter: $filter, first: $first) {
nodes {
id
identifier
title
description
}
}
}
"""
CUSTOMER_UPSERT_MUTATION = """
mutation CustomerUpsert($input: CustomerUpsertInput!) {
customerUpsert(input: $input) {
success
customer {
id
name
externalIds
}
}
}
"""
ISSUE_CREATE_MUTATION = """
mutation IssueCreate($input: IssueCreateInput!) {
issueCreate(input: $input) {
success
issue {
id
identifier
title
url
}
}
}
"""
CUSTOMER_NEED_CREATE_MUTATION = """
mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) {
customerNeedCreate(input: $input) {
success
need {
id
body
customer {
id
name
}
issue {
id
identifier
title
url
}
}
}
}
"""
_settings: Settings | None = None
def _get_settings() -> Settings:
global _settings
if _settings is None:
_settings = Settings()
return _settings
def _get_linear_config() -> tuple[LinearClient, str, str]:
"""Return a configured Linear client, project ID, and team ID.
Raises RuntimeError if any required setting is missing.
"""
secrets = _get_settings().secrets
if not secrets.linear_api_key:
raise RuntimeError("LINEAR_API_KEY is not configured")
if not secrets.linear_feature_request_project_id:
raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured")
if not secrets.linear_feature_request_team_id:
raise RuntimeError("LINEAR_FEATURE_REQUEST_TEAM_ID is not configured")
credentials = APIKeyCredentials(
id="system-linear",
provider="linear",
api_key=SecretStr(secrets.linear_api_key),
title="System Linear API Key",
)
client = LinearClient(credentials=credentials)
return (
client,
secrets.linear_feature_request_project_id,
secrets.linear_feature_request_team_id,
)
class SearchFeatureRequestsTool(BaseTool):
"""Tool for searching existing feature requests in Linear."""
@property
def name(self) -> str:
return "search_feature_requests"
@property
def description(self) -> str:
return (
"Search existing feature requests to check if a similar request "
"already exists before creating a new one. Returns matching feature "
"requests with their ID, title, and description."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search term to find matching feature requests.",
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not query:
return ErrorResponse(
message="Please provide a search query.",
error="Missing query parameter",
session_id=session_id,
)
try:
client, project_id, _team_id = _get_linear_config()
data = await client.query(
SEARCH_ISSUES_QUERY,
{
"term": query,
"filter": {
"project": {"id": {"eq": project_id}},
},
"first": MAX_SEARCH_RESULTS,
},
)
nodes = data.get("searchIssues", {}).get("nodes", [])
if not nodes:
return NoResultsResponse(
message=f"No feature requests found matching '{query}'.",
suggestions=[
"Try different keywords",
"Use broader search terms",
"You can create a new feature request if none exists",
],
session_id=session_id,
)
results = [
FeatureRequestInfo(
id=node["id"],
identifier=node["identifier"],
title=node["title"],
description=node.get("description"),
)
for node in nodes
]
return FeatureRequestSearchResponse(
message=f"Found {len(results)} feature request(s) matching '{query}'.",
results=results,
count=len(results),
query=query,
session_id=session_id,
)
except Exception as e:
logger.exception("Failed to search feature requests")
return ErrorResponse(
message="Failed to search feature requests.",
error=str(e),
session_id=session_id,
)
class CreateFeatureRequestTool(BaseTool):
"""Tool for creating feature requests (or adding needs to existing ones)."""
@property
def name(self) -> str:
return "create_feature_request"
@property
def description(self) -> str:
return (
"Create a new feature request or add a customer need to an existing one. "
"Always search first with search_feature_requests to avoid duplicates. "
"If a matching request exists, pass its ID as existing_issue_id to add "
"the user's need to it instead of creating a duplicate."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Title for the feature request.",
},
"description": {
"type": "string",
"description": "Detailed description of what the user wants and why.",
},
"existing_issue_id": {
"type": "string",
"description": (
"If adding a need to an existing feature request, "
"provide its Linear issue ID (from search results). "
"Omit to create a new feature request."
),
},
},
"required": ["title", "description"],
}
@property
def requires_auth(self) -> bool:
return True
async def _find_or_create_customer(
self, client: LinearClient, user_id: str, name: str
) -> dict:
"""Find existing customer by user_id or create a new one via upsert.
Args:
client: Linear API client.
user_id: Stable external ID used to deduplicate customers.
name: Human-readable display name (e.g. the user's email).
"""
data = await client.mutate(
CUSTOMER_UPSERT_MUTATION,
{
"input": {
"name": name,
"externalId": user_id,
},
},
)
result = data.get("customerUpsert", {})
if not result.get("success"):
raise RuntimeError(f"Failed to upsert customer: {data}")
return result["customer"]
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
title = kwargs.get("title", "").strip()
description = kwargs.get("description", "").strip()
existing_issue_id = kwargs.get("existing_issue_id")
session_id = session.session_id if session else None
if not title or not description:
return ErrorResponse(
message="Both title and description are required.",
error="Missing required parameters",
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="Authentication required to create feature requests.",
error="Missing user_id",
session_id=session_id,
)
try:
client, project_id, team_id = _get_linear_config()
except Exception as e:
logger.exception("Failed to initialize Linear client")
return ErrorResponse(
message="Failed to create feature request.",
error=str(e),
session_id=session_id,
)
# Resolve a human-readable name (email) for the Linear customer record.
# Fall back to user_id if the lookup fails or returns None.
try:
customer_display_name = await get_user_email_by_id(user_id) or user_id
except Exception:
customer_display_name = user_id
# Step 1: Find or create customer for this user
try:
customer = await self._find_or_create_customer(
client, user_id, customer_display_name
)
customer_id = customer["id"]
customer_name = customer["name"]
except Exception as e:
logger.exception("Failed to upsert customer in Linear")
return ErrorResponse(
message="Failed to create feature request.",
error=str(e),
session_id=session_id,
)
# Step 2: Create or reuse issue
issue_id: str | None = None
issue_identifier: str | None = None
if existing_issue_id:
# Add need to existing issue - we still need the issue details for response
is_new_issue = False
issue_id = existing_issue_id
else:
# Create new issue in the feature requests project
try:
data = await client.mutate(
ISSUE_CREATE_MUTATION,
{
"input": {
"title": title,
"description": description,
"teamId": team_id,
"projectId": project_id,
},
},
)
result = data.get("issueCreate", {})
if not result.get("success"):
return ErrorResponse(
message="Failed to create feature request issue.",
error=str(data),
session_id=session_id,
)
issue = result["issue"]
issue_id = issue["id"]
issue_identifier = issue.get("identifier")
except Exception as e:
logger.exception("Failed to create feature request issue")
return ErrorResponse(
message="Failed to create feature request.",
error=str(e),
session_id=session_id,
)
is_new_issue = True
# Step 3: Create customer need on the issue
try:
data = await client.mutate(
CUSTOMER_NEED_CREATE_MUTATION,
{
"input": {
"customerId": customer_id,
"issueId": issue_id,
"body": description,
"priority": 0,
},
},
)
need_result = data.get("customerNeedCreate", {})
if not need_result.get("success"):
orphaned = (
{"issue_id": issue_id, "issue_identifier": issue_identifier}
if is_new_issue
else None
)
return ErrorResponse(
message="Failed to attach customer need to the feature request.",
error=str(data),
details=orphaned,
session_id=session_id,
)
need = need_result["need"]
issue_info = need["issue"]
except Exception as e:
logger.exception("Failed to create customer need")
orphaned = (
{"issue_id": issue_id, "issue_identifier": issue_identifier}
if is_new_issue
else None
)
return ErrorResponse(
message="Failed to attach customer need to the feature request.",
error=str(e),
details=orphaned,
session_id=session_id,
)
return FeatureRequestCreatedResponse(
message=(
f"{'Created new feature request' if is_new_issue else 'Added your request to existing feature request'}: "
f"{issue_info['title']}."
),
issue_id=issue_info["id"],
issue_identifier=issue_info["identifier"],
issue_title=issue_info["title"],
issue_url=issue_info.get("url", ""),
is_new_issue=is_new_issue,
customer_name=customer_name,
session_id=session_id,
)

View File

@@ -1,615 +0,0 @@
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""
from unittest.mock import AsyncMock, patch
import pytest
from backend.api.features.chat.tools.feature_requests import (
CreateFeatureRequestTool,
SearchFeatureRequestsTool,
)
from backend.api.features.chat.tools.models import (
ErrorResponse,
FeatureRequestCreatedResponse,
FeatureRequestSearchResponse,
NoResultsResponse,
)
from ._test_data import make_session
_TEST_USER_ID = "test-user-feature-requests"
_TEST_USER_EMAIL = "testuser@example.com"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_FAKE_PROJECT_ID = "test-project-id"
_FAKE_TEAM_ID = "test-team-id"
def _mock_linear_config(*, query_return=None, mutate_return=None):
"""Return a patched _get_linear_config that yields a mock LinearClient."""
client = AsyncMock()
if query_return is not None:
client.query.return_value = query_return
if mutate_return is not None:
client.mutate.return_value = mutate_return
return (
patch(
"backend.api.features.chat.tools.feature_requests._get_linear_config",
return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
),
client,
)
def _search_response(nodes: list[dict]) -> dict:
return {"searchIssues": {"nodes": nodes}}
def _customer_upsert_response(
customer_id: str = "cust-1", name: str = _TEST_USER_EMAIL, success: bool = True
) -> dict:
return {
"customerUpsert": {
"success": success,
"customer": {"id": customer_id, "name": name, "externalIds": [name]},
}
}
def _issue_create_response(
issue_id: str = "issue-1",
identifier: str = "FR-1",
title: str = "New Feature",
success: bool = True,
) -> dict:
return {
"issueCreate": {
"success": success,
"issue": {
"id": issue_id,
"identifier": identifier,
"title": title,
"url": f"https://linear.app/issue/{identifier}",
},
}
}
def _need_create_response(
need_id: str = "need-1",
issue_id: str = "issue-1",
identifier: str = "FR-1",
title: str = "New Feature",
success: bool = True,
) -> dict:
return {
"customerNeedCreate": {
"success": success,
"need": {
"id": need_id,
"body": "description",
"customer": {"id": "cust-1", "name": _TEST_USER_EMAIL},
"issue": {
"id": issue_id,
"identifier": identifier,
"title": title,
"url": f"https://linear.app/issue/{identifier}",
},
},
}
}
# ===========================================================================
# SearchFeatureRequestsTool
# ===========================================================================
class TestSearchFeatureRequestsTool:
"""Tests for SearchFeatureRequestsTool._execute."""
@pytest.mark.asyncio(loop_scope="session")
async def test_successful_search(self):
session = make_session(user_id=_TEST_USER_ID)
nodes = [
{
"id": "id-1",
"identifier": "FR-1",
"title": "Dark mode",
"description": "Add dark mode support",
},
{
"id": "id-2",
"identifier": "FR-2",
"title": "Dark theme",
"description": None,
},
]
patcher, _ = _mock_linear_config(query_return=_search_response(nodes))
with patcher:
tool = SearchFeatureRequestsTool()
resp = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="dark mode"
)
assert isinstance(resp, FeatureRequestSearchResponse)
assert resp.count == 2
assert resp.results[0].id == "id-1"
assert resp.results[1].identifier == "FR-2"
assert resp.query == "dark mode"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_results(self):
session = make_session(user_id=_TEST_USER_ID)
patcher, _ = _mock_linear_config(query_return=_search_response([]))
with patcher:
tool = SearchFeatureRequestsTool()
resp = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="nonexistent"
)
assert isinstance(resp, NoResultsResponse)
assert "nonexistent" in resp.message
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_query_returns_error(self):
session = make_session(user_id=_TEST_USER_ID)
tool = SearchFeatureRequestsTool()
resp = await tool._execute(user_id=_TEST_USER_ID, session=session, query=" ")
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "query" in resp.error.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_query_returns_error(self):
session = make_session(user_id=_TEST_USER_ID)
tool = SearchFeatureRequestsTool()
resp = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(resp, ErrorResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_api_failure(self):
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.query.side_effect = RuntimeError("Linear API down")
with patcher:
tool = SearchFeatureRequestsTool()
resp = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "Linear API down" in resp.error
@pytest.mark.asyncio(loop_scope="session")
async def test_malformed_node_returns_error(self):
"""A node missing required keys should be caught by the try/except."""
session = make_session(user_id=_TEST_USER_ID)
# Node missing 'identifier' key
bad_nodes = [{"id": "id-1", "title": "Missing identifier"}]
patcher, _ = _mock_linear_config(query_return=_search_response(bad_nodes))
with patcher:
tool = SearchFeatureRequestsTool()
resp = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
assert isinstance(resp, ErrorResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_linear_client_init_failure(self):
session = make_session(user_id=_TEST_USER_ID)
with patch(
"backend.api.features.chat.tools.feature_requests._get_linear_config",
side_effect=RuntimeError("No API key"),
):
tool = SearchFeatureRequestsTool()
resp = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "No API key" in resp.error
# ===========================================================================
# CreateFeatureRequestTool
# ===========================================================================
class TestCreateFeatureRequestTool:
"""Tests for CreateFeatureRequestTool._execute."""
@pytest.fixture(autouse=True)
def _patch_email_lookup(self):
with patch(
"backend.api.features.chat.tools.feature_requests.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TEST_USER_EMAIL,
):
yield
# ---- Happy paths -------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_create_new_issue(self):
"""Full happy path: upsert customer -> create issue -> attach need."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_issue_create_response(),
_need_create_response(),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="New Feature",
description="Please add this",
)
assert isinstance(resp, FeatureRequestCreatedResponse)
assert resp.is_new_issue is True
assert resp.issue_identifier == "FR-1"
assert resp.customer_name == _TEST_USER_EMAIL
assert client.mutate.call_count == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_add_need_to_existing_issue(self):
"""When existing_issue_id is provided, skip issue creation."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_need_create_response(issue_id="existing-1", identifier="FR-99"),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Existing Feature",
description="Me too",
existing_issue_id="existing-1",
)
assert isinstance(resp, FeatureRequestCreatedResponse)
assert resp.is_new_issue is False
assert resp.issue_id == "existing-1"
# Only 2 mutations: customer upsert + need create (no issue create)
assert client.mutate.call_count == 2
# ---- Validation errors -------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_title(self):
session = make_session(user_id=_TEST_USER_ID)
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="",
description="some desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "required" in resp.error.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_description(self):
session = make_session(user_id=_TEST_USER_ID)
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Some title",
description="",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "required" in resp.error.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_user_id(self):
session = make_session(user_id=_TEST_USER_ID)
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=None,
session=session,
title="Some title",
description="Some desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "user_id" in resp.error.lower()
# ---- Linear client init failure ----------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_linear_client_init_failure(self):
session = make_session(user_id=_TEST_USER_ID)
with patch(
"backend.api.features.chat.tools.feature_requests._get_linear_config",
side_effect=RuntimeError("No API key"),
):
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "No API key" in resp.error
# ---- Customer upsert failures ------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_customer_upsert_api_error(self):
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = RuntimeError("Customer API error")
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "Customer API error" in resp.error
@pytest.mark.asyncio(loop_scope="session")
async def test_customer_upsert_not_success(self):
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.return_value = _customer_upsert_response(success=False)
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_customer_malformed_response(self):
"""Customer dict missing 'id' key should be caught."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
# success=True but customer has no 'id'
client.mutate.return_value = {
"customerUpsert": {
"success": True,
"customer": {"name": _TEST_USER_ID},
}
}
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
# ---- Issue creation failures -------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_issue_create_api_error(self):
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
RuntimeError("Issue create failed"),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "Issue create failed" in resp.error
@pytest.mark.asyncio(loop_scope="session")
async def test_issue_create_not_success(self):
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_issue_create_response(success=False),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert "Failed to create feature request issue" in resp.message
@pytest.mark.asyncio(loop_scope="session")
async def test_issue_create_malformed_response(self):
"""issueCreate success=True but missing 'issue' key."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
{"issueCreate": {"success": True}}, # no 'issue' key
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
# ---- Customer need attachment failures ---------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_need_create_api_error_new_issue(self):
"""Need creation fails after new issue was created -> orphaned issue info."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_issue_create_response(issue_id="orphan-1", identifier="FR-10"),
RuntimeError("Need attach failed"),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.error is not None
assert "Need attach failed" in resp.error
assert resp.details is not None
assert resp.details["issue_id"] == "orphan-1"
assert resp.details["issue_identifier"] == "FR-10"
@pytest.mark.asyncio(loop_scope="session")
async def test_need_create_api_error_existing_issue(self):
"""Need creation fails on existing issue -> no orphaned info."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
RuntimeError("Need attach failed"),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
existing_issue_id="existing-1",
)
assert isinstance(resp, ErrorResponse)
assert resp.details is None
@pytest.mark.asyncio(loop_scope="session")
async def test_need_create_not_success_includes_orphaned_info(self):
"""customerNeedCreate returns success=False -> includes orphaned issue."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_issue_create_response(issue_id="orphan-2", identifier="FR-20"),
_need_create_response(success=False),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.details is not None
assert resp.details["issue_id"] == "orphan-2"
assert resp.details["issue_identifier"] == "FR-20"
@pytest.mark.asyncio(loop_scope="session")
async def test_need_create_not_success_existing_issue_no_details(self):
"""customerNeedCreate fails on existing issue -> no orphaned info."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_need_create_response(success=False),
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
existing_issue_id="existing-1",
)
assert isinstance(resp, ErrorResponse)
assert resp.details is None
@pytest.mark.asyncio(loop_scope="session")
async def test_need_create_malformed_response(self):
"""need_result missing 'need' key after success=True."""
session = make_session(user_id=_TEST_USER_ID)
patcher, client = _mock_linear_config()
client.mutate.side_effect = [
_customer_upsert_response(),
_issue_create_response(),
{"customerNeedCreate": {"success": True}}, # no 'need' key
]
with patcher:
tool = CreateFeatureRequestTool()
resp = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
title="Title",
description="Desc",
)
assert isinstance(resp, ErrorResponse)
assert resp.details is not None
assert resp.details["issue_id"] == "issue-1"

View File

@@ -146,7 +146,6 @@ class FindBlockTool(BaseTool):
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
)
)

View File

@@ -41,15 +41,6 @@ class ResponseType(str, Enum):
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Base response model
@@ -344,19 +335,6 @@ class BlockInfoSummary(BaseModel):
id: str
name: str
description: str
categories: list[str]
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
)
class BlockListResponse(ToolResponseBase):
@@ -366,10 +344,6 @@ class BlockListResponse(ToolResponseBase):
blocks: list[BlockInfoSummary]
count: int
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs."
)
class BlockDetails(BaseModel):
@@ -456,55 +430,3 @@ class AsyncProcessingResponse(ToolResponseBase):
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
type: ResponseType = ResponseType.WEB_FETCH
url: str
status_code: int
content_type: str
content: str
truncated: bool = False
class BashExecResponse(ToolResponseBase):
"""Response for bash_exec tool."""
type: ResponseType = ResponseType.BASH_EXEC
stdout: str
stderr: str
exit_code: int
timed_out: bool = False
# Feature request models
class FeatureRequestInfo(BaseModel):
"""Information about a feature request issue."""
id: str
identifier: str
title: str
description: str | None = None
class FeatureRequestSearchResponse(ToolResponseBase):
"""Response for search_feature_requests tool."""
type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH
results: list[FeatureRequestInfo]
count: int
query: str
class FeatureRequestCreatedResponse(ToolResponseBase):
"""Response for create_feature_request tool."""
type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED
issue_id: str
issue_identifier: str
issue_title: str
issue_url: str
is_new_issue: bool # False if added to existing
customer_name: str

View File

@@ -1,265 +0,0 @@
"""Sandbox execution utilities for code execution tools.
Provides filesystem + network isolated command execution using **bubblewrap**
(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
writable workspace only, clean environment, network blocked.
Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
and refuse to run if bubblewrap is not available.
"""
import asyncio
import logging
import os
import platform
import shutil
logger = logging.getLogger(__name__)
_DEFAULT_TIMEOUT = 30
_MAX_TIMEOUT = 120
# ---------------------------------------------------------------------------
# Sandbox capability detection (cached at first call)
# ---------------------------------------------------------------------------
_BWRAP_AVAILABLE: bool | None = None
def has_full_sandbox() -> bool:
"""Return True if bubblewrap is available (filesystem + network isolation).
On non-Linux platforms (macOS), always returns False.
"""
global _BWRAP_AVAILABLE
if _BWRAP_AVAILABLE is None:
_BWRAP_AVAILABLE = (
platform.system() == "Linux" and shutil.which("bwrap") is not None
)
return _BWRAP_AVAILABLE
WORKSPACE_PREFIX = "/tmp/copilot-"
def make_session_path(session_id: str) -> str:
"""Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
Shared by both the SDK working-directory setup and the sandbox tools so
they always resolve to the same directory for a given session.
Steps:
1. Strip all characters except ``[A-Za-z0-9-]``.
2. Construct ``/tmp/copilot-<safe_id>``.
3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
sanitizer) to prevent path traversal.
Raises:
ValueError: If the resulting path escapes the prefix.
"""
import re
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
if not safe_id:
safe_id = "default"
path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
if not path.startswith(WORKSPACE_PREFIX):
raise ValueError(f"Session path escaped prefix: {path}")
return path
def get_workspace_dir(session_id: str) -> str:
"""Get or create the workspace directory for a session.
Uses :func:`make_session_path` — the same path the SDK uses — so that
bash_exec shares the workspace with the SDK file tools.
"""
workspace = make_session_path(session_id)
os.makedirs(workspace, exist_ok=True)
return workspace
# ---------------------------------------------------------------------------
# Bubblewrap command builder
# ---------------------------------------------------------------------------
# System directories mounted read-only inside the sandbox.
# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
_SYSTEM_RO_BINDS = [
"/usr", # binaries, libraries, Python interpreter
"/etc", # system config: ld.so, locale, passwd, alternatives
]
# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
# can't create a symlink target, so we detect and use --symlink instead.
# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
_COMPAT_PATHS = [
("/bin", "usr/bin"), # -> /usr/bin on Debian 13
("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13
("/lib", "usr/lib"), # -> /usr/lib on Debian 13
("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter
]
# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
# Applied via ulimit inside the sandbox before exec'ing the user command.
_RESOURCE_LIMITS = (
"ulimit -u 64" # max 64 processes (prevents fork bombs)
" -v 524288" # 512 MB virtual memory
" -f 51200" # 50 MB max file size (1024-byte blocks)
" -n 256" # 256 open file descriptors
" 2>/dev/null"
)
def _build_bwrap_command(
command: list[str], cwd: str, env: dict[str, str]
) -> list[str]:
"""Build a bubblewrap command with strict filesystem + network isolation.
Security model:
- **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
- **Writable workspace only**: the per-session workspace is the sole
writable path.
- **Clean environment**: ``--clearenv`` wipes all inherited env vars.
Only the explicitly-passed safe env vars are set inside the sandbox.
- **Network isolation**: ``--unshare-net`` blocks all network access.
- **Resource limits**: ulimit caps on processes (64), memory (512MB),
file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
- **New session**: prevents terminal control escape.
- **Die with parent**: prevents orphaned sandbox processes.
"""
cmd = [
"bwrap",
# Create a new user namespace so bwrap can set up sandboxing
# inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
"--unshare-user",
# Wipe all inherited environment variables (API keys, secrets, etc.)
"--clearenv",
]
# Set only the safe env vars inside the sandbox
for key, value in env.items():
cmd.extend(["--setenv", key, value])
# System directories: read-only
for path in _SYSTEM_RO_BINDS:
cmd.extend(["--ro-bind", path, path])
# Compat paths: use --symlink when host path is a symlink (Debian 13),
# --ro-bind when it's a real directory (older distros).
for path, symlink_target in _COMPAT_PATHS:
if os.path.islink(path):
cmd.extend(["--symlink", symlink_target, path])
elif os.path.exists(path):
cmd.extend(["--ro-bind", path, path])
# Wrap the user command with resource limits:
# sh -c 'ulimit ...; exec "$@"' -- <original command>
# `exec "$@"` replaces the shell so there's no extra process overhead,
# and properly handles arguments with spaces.
limited_command = [
"sh",
"-c",
f'{_RESOURCE_LIMITS}; exec "$@"',
"--",
*command,
]
cmd.extend(
[
# Fresh virtual filesystems
"--dev",
"/dev",
"--proc",
"/proc",
"--tmpfs",
"/tmp",
# Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
# (workspace lives under /tmp/copilot-<session>)
"--bind",
cwd,
cwd,
# Isolation
"--unshare-net",
"--die-with-parent",
"--new-session",
"--chdir",
cwd,
"--",
*limited_command,
]
)
return cmd
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def run_sandboxed(
command: list[str],
cwd: str,
timeout: int = _DEFAULT_TIMEOUT,
env: dict[str, str] | None = None,
) -> tuple[str, str, int, bool]:
"""Run a command inside a bubblewrap sandbox.
Callers **must** check :func:`has_full_sandbox` before calling this
function. If bubblewrap is not available, this function raises
:class:`RuntimeError` rather than running unsandboxed.
Returns:
(stdout, stderr, exit_code, timed_out)
"""
if not has_full_sandbox():
raise RuntimeError(
"run_sandboxed() requires bubblewrap but bwrap is not available. "
"Callers must check has_full_sandbox() before calling this function."
)
timeout = min(max(timeout, 1), _MAX_TIMEOUT)
safe_env = {
"PATH": "/usr/local/bin:/usr/bin:/bin",
"HOME": cwd,
"TMPDIR": cwd,
"LANG": "en_US.UTF-8",
"PYTHONDONTWRITEBYTECODE": "1",
"PYTHONIOENCODING": "utf-8",
}
if env:
safe_env.update(env)
full_command = _build_bwrap_command(command, cwd, safe_env)
try:
proc = await asyncio.create_subprocess_exec(
*full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
env=safe_env,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
proc.communicate(), timeout=timeout
)
stdout = stdout_bytes.decode("utf-8", errors="replace")
stderr = stderr_bytes.decode("utf-8", errors="replace")
return stdout, stderr, proc.returncode or 0, False
except asyncio.TimeoutError:
proc.kill()
await proc.communicate()
return "", f"Execution timed out after {timeout}s", -1, True
except RuntimeError:
raise
except Exception as e:
return "", f"Sandbox error: {e}", -1, False

View File

@@ -15,7 +15,6 @@ from backend.data.model import (
OAuth2Credentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -360,7 +359,7 @@ async def match_user_credentials_to_graph(
_,
_,
) in aggregated_creds.items():
# Find first matching credential by provider, type, scopes, and host/URL
# Find first matching credential by provider, type, and scopes
matching_cred = next(
(
cred
@@ -375,10 +374,6 @@ async def match_user_credentials_to_graph(
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
),
None,
)
@@ -449,22 +444,6 @@ def _credential_is_for_host(
return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -1,151 +0,0 @@
"""Web fetch tool — safely retrieve public web page content."""
import logging
from typing import Any
import aiohttp
import html2text
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ToolResponseBase,
WebFetchResponse,
)
from backend.util.request import Requests
logger = logging.getLogger(__name__)
# Limits
_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap
_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
# Content types we'll read as text
_TEXT_CONTENT_TYPES = {
"text/html",
"text/plain",
"text/xml",
"text/csv",
"text/markdown",
"application/json",
"application/xml",
"application/xhtml+xml",
"application/rss+xml",
"application/atom+xml",
}
def _is_text_content(content_type: str) -> bool:
base = content_type.split(";")[0].strip().lower()
return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
def _html_to_text(html: str) -> str:
h = html2text.HTML2Text()
h.ignore_links = False
h.ignore_images = True
h.body_width = 0
return h.handle(html)
class WebFetchTool(BaseTool):
"""Safely fetch content from a public URL using SSRF-protected HTTP."""
@property
def name(self) -> str:
return "web_fetch"
@property
def description(self) -> str:
return (
"Fetch the content of a public web page by URL. "
"Returns readable text extracted from HTML by default. "
"Useful for reading documentation, articles, and API responses. "
"Only supports HTTP/HTTPS GET requests to public URLs "
"(private/internal network addresses are blocked)."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The public HTTP/HTTPS URL to fetch.",
},
"extract_text": {
"type": "boolean",
"description": (
"If true (default), extract readable text from HTML. "
"If false, return raw content."
),
"default": True,
},
},
"required": ["url"],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
url: str = (kwargs.get("url") or "").strip()
extract_text: bool = kwargs.get("extract_text", True)
session_id = session.session_id if session else None
if not url:
return ErrorResponse(
message="Please provide a URL to fetch.",
error="missing_url",
session_id=session_id,
)
try:
client = Requests(raise_for_status=False, retry_max_attempts=1)
response = await client.get(url, timeout=_REQUEST_TIMEOUT)
except ValueError as e:
# validate_url raises ValueError for SSRF / blocked IPs
return ErrorResponse(
message=f"URL blocked: {e}",
error="url_blocked",
session_id=session_id,
)
except Exception as e:
logger.warning(f"[web_fetch] Request failed for {url}: {e}")
return ErrorResponse(
message=f"Failed to fetch URL: {e}",
error="fetch_failed",
session_id=session_id,
)
content_type = response.headers.get("content-type", "")
if not _is_text_content(content_type):
return ErrorResponse(
message=f"Non-text content type: {content_type.split(';')[0]}",
error="unsupported_content_type",
session_id=session_id,
)
raw = response.content[:_MAX_CONTENT_BYTES]
text = raw.decode("utf-8", errors="replace")
if extract_text and "html" in content_type.lower():
text = _html_to_text(text)
return WebFetchResponse(
message=f"Fetched {url}",
url=response.url,
status_code=response.status,
content_type=content_type.split(";")[0].strip(),
content=text,
truncated=False,
session_id=session_id,
)

View File

@@ -88,9 +88,7 @@ class ListWorkspaceFilesTool(BaseTool):
@property
def description(self) -> str:
return (
"List files in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@@ -206,9 +204,7 @@ class ReadWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Read a file from the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
@@ -382,9 +378,7 @@ class WriteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Write or create a file in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
@@ -529,7 +523,7 @@ class DeleteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Delete a file from the user's persistent workspace (cloud storage). "
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id
from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security,
status,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
default=None, description="Host pattern for host-scoped credentials"
)
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
@@ -211,7 +179,9 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)),
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -229,7 +199,7 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -352,11 +322,7 @@ async def delete_credentials(
tokens_revoked = None
if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value):
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider)
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked)

View File

@@ -1,404 +0,0 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -1,436 +0,0 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()

View File

@@ -393,6 +393,7 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
operation_id="getV2GetCreatorDetails",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)

View File

@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.llm_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -26,7 +27,6 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db
import backend.api.features.library.model
import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.postmark.postmark
@@ -39,13 +39,15 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -116,11 +118,27 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.blocks._base import BlockSchema
BlockSchema.clear_all_schema_caches()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
# migrate_llm_models uses registry default model
from backend.blocks.llm import LlmModel
default_model_slug = llm_registry.get_default_model_slug()
if default_model_slug:
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
else:
logger.warning("Skipping LLM model migration: no default model available")
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
@@ -322,6 +340,16 @@ app.include_router(
tags=["v2", "executions", "review"],
prefix="/api/review",
)
app.include_router(
backend.api.features.admin.llm_routes.router,
tags=["v2", "admin", "llm"],
prefix="/api/llm/admin",
)
app.include_router(
public_llm_routes.router,
tags=["v2", "llm"],
prefix="/api",
)
app.include_router(
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -344,11 +372,6 @@ app.include_router(
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
mcp_routes.router,
tags=["v2", "mcp"],
prefix="/api/mcp",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -79,11 +79,49 @@ async def event_broadcaster(manager: ConnectionManager):
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
# Track registry pubsub for cleanup
registry_pubsub = None
async def registry_refresh_worker():
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
nonlocal registry_pubsub
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
from backend.data.redis_client import connect_async
redis = await connect_async()
registry_pubsub = redis.pubsub()
await registry_pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
)
async for message in registry_pubsub.listen():
if (
message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info(
"Broadcasting LLM registry refresh to all WebSocket clients"
)
await manager.broadcast_to_all(
method=WSMethod.NOTIFICATION,
data={
"type": "LLM_REGISTRY_REFRESH",
"event": "registry_updated",
},
)
await asyncio.gather(
execution_worker(),
notification_worker(),
registry_refresh_worker(),
)
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
if registry_pubsub:
await registry_pubsub.close()
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -64,7 +64,6 @@ class BlockType(Enum):
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
MCP_TOOL = "MCP Tool"
class BlockCategory(Enum):
@@ -134,7 +133,26 @@ class BlockInfo(BaseModel):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
@classmethod
def clear_schema_cache(cls) -> None:
"""Clear the cached JSON schema for this class."""
# Use None instead of {} because {} is truthy and would prevent regeneration
cls.cached_jsonschema = None # type: ignore
@staticmethod
def clear_all_schema_caches() -> None:
"""Clear cached JSON schemas for all BlockSchema subclasses."""
def clear_recursive(cls: type) -> None:
"""Recursively clear cache for class and all subclasses."""
if hasattr(cls, "clear_schema_cache"):
cls.clear_schema_cache()
for subclass in cls.__subclasses__():
clear_recursive(subclass)
clear_recursive(BlockSchema)
@classmethod
def jsonschema(cls) -> dict[str, Any]:
@@ -225,7 +243,8 @@ class BlockSchema(BaseModel):
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
# Use None instead of {} because {} is truthy and would prevent regeneration
cls.cached_jsonschema = None
credentials_fields = cls.get_credentials_fields()

View File

@@ -7,7 +7,6 @@ from backend.blocks._base import (
BlockSchemaOutput,
)
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -16,6 +15,7 @@ from backend.blocks.llm import (
LlmModel,
LLMResponse,
llm_call,
llm_model_schema_extra,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
@@ -50,9 +50,10 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for evaluating the condition.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
@@ -82,7 +83,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.default(),
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -126,7 +126,6 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -17,7 +17,6 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
class SearchTheWebBlock(Block, GetRequest):
@@ -111,12 +110,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"
return
url = input_data.url
headers = {}
else:
url = f"https://r.jina.ai/{input_data.url}"
@@ -125,20 +119,5 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
}
try:
content = await self.get_request(url, json=False, headers=headers)
except HTTPClientError as e:
yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}"
return
except HTTPServerError as e:
yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}"
return
except Exception as e:
yield "error", f"Failed to fetch {input_data.url}: {e}"
return
if not content:
yield "error", f"No content returned for {input_data.url}"
return
content = await self.get_request(url, json=False, headers=headers)
yield "content", content

View File

@@ -4,16 +4,18 @@ import logging
import re
import secrets
from abc import ABC
from enum import Enum, EnumMeta
from dataclasses import dataclass
from enum import Enum
from json import JSONDecodeError
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
from typing import Any, Iterable, List, Literal, Optional
import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, SecretStr
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
from pydantic_core import CoreSchema, core_schema
from backend.blocks._base import (
Block,
@@ -22,6 +24,8 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data import llm_registry
from backend.data.llm_registry import ModelMetadata
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -66,114 +70,123 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
"""
Returns a CredentialsField for LLM providers.
The discriminator_mapping will be refreshed when the schema is generated
if it's empty, ensuring the LLM registry is loaded.
"""
# Get the mapping now - it may be empty initially, but will be refreshed
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
mapping = llm_registry.get_llm_discriminator_mapping()
return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
discriminator_mapping=mapping, # May be empty initially, refreshed later
)
class ModelMetadata(NamedTuple):
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]
def llm_model_schema_extra() -> dict[str, Any]:
return {"options": llm_registry.get_llm_model_schema_options()}
class LlmModelMeta(EnumMeta):
pass
class LlmModelMeta(type):
"""
Metaclass for LlmModel that enables attribute-style access to dynamic models.
This allows code like `LlmModel.GPT4O` to work by converting the attribute
name to a slug format:
- GPT4O -> gpt-4o
- GPT4O_MINI -> gpt-4o-mini
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
"""
def __getattr__(cls, name: str):
# Don't intercept private/dunder attributes
if name.startswith("_"):
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
# Convert attribute name to slug format:
# 1. Lowercase: GPT4O -> gpt4o
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
slug = name.lower().replace("_", "-")
# Check for exact match in registry first (e.g., "o1" stays "o1")
registry_slugs = llm_registry.get_dynamic_model_slugs()
if slug in registry_slugs:
return cls(slug)
# If no exact match, try inserting hyphen between letter and digit
# e.g., gpt4o -> gpt-4o
transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
return cls(transformed_slug)
def __iter__(cls):
"""Iterate over all models from the registry.
Yields LlmModel instances for each model in the dynamic registry.
Used by __get_pydantic_json_schema__ to build model metadata.
"""
for model in llm_registry.iter_dynamic_models():
yield cls(model.slug)
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
# Groq models
LLAMA3_3_70B = "llama-3.3-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
# Ollama models
OLLAMA_LLAMA3_3 = "llama3.3"
OLLAMA_LLAMA3_2 = "llama3.2"
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
class LlmModel(str, metaclass=LlmModelMeta):
"""
Dynamic LLM model type that accepts any model slug from the registry.
This is a string subclass (not an Enum) that allows any model slug value.
All models are managed via the LLM Registry in the database.
Usage:
model = LlmModel("gpt-4o") # Direct construction
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
model.value # Returns the slug string
model.provider # Returns the provider from registry
"""
def __new__(cls, value: str):
if isinstance(value, LlmModel):
return value
return str.__new__(cls, value)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""
Tell Pydantic how to validate LlmModel.
Accepts strings and converts them to LlmModel instances.
"""
return core_schema.no_info_after_validator_function(
cls, # The validator function (LlmModel constructor)
core_schema.str_schema(), # Accept string input
serialization=core_schema.to_string_ser_schema(), # Serialize as string
)
@property
def value(self) -> str:
"""Return the model slug (for compatibility with enum-style access)."""
return str(self)
@classmethod
def default(cls) -> "LlmModel":
"""
Get the default model from the registry.
Returns the recommended model if set, otherwise gpt-4o if available
and enabled, otherwise the first enabled model from the registry.
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
"""
from backend.data.llm_registry import get_default_model_slug
slug = get_default_model_slug()
if slug is None:
# Registry is empty (e.g., at module import time before DB connection).
# Fall back to gpt-4o for backward compatibility.
slug = "gpt-4o"
return cls(slug)
@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
@@ -181,7 +194,15 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
llm_model_metadata = {}
for model in cls:
model_name = model.value
metadata = model.metadata
# Skip disabled models - only show enabled models in the picker
if not llm_registry.is_model_enabled(model_name):
continue
# Use registry directly with None check to gracefully handle
# missing metadata during startup/import before registry is populated
metadata = llm_registry.get_llm_model_metadata(model_name)
if metadata is None:
# Skip models without metadata (registry not yet populated)
continue
llm_model_metadata[model_name] = {
"creator": metadata.creator_name,
"creator_name": metadata.creator_name,
@@ -197,7 +218,12 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[self]
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
raise ValueError(
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
)
@property
def provider(self) -> str:
@@ -212,300 +238,125 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
return self.metadata.max_output_tokens
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
LlmModel.O3_MINI: ModelMetadata(
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata(
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata(
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata(
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
),
LlmModel.GPT5_1: ModelMetadata(
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
),
LlmModel.GPT5: ModelMetadata(
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_MINI: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_NANO: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_CHAT: ModelMetadata(
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
),
LlmModel.GPT41: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
),
LlmModel.GPT41_MINI: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
), # gpt-4o-mini-2024-07-18
LlmModel.GPT4O: ModelMetadata(
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
# https://docs.aimlapi.com/api-overview/model-database/text-models
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
),
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
"aiml_api",
128000,
40000,
"Llama 3.1 Nemotron 70B Instruct",
"AI/ML",
"Nvidia",
1,
),
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
),
# https://console.groq.com/docs/models
LlmModel.LLAMA3_3_70B: ModelMetadata(
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
),
LlmModel.LLAMA3_1_8B: ModelMetadata(
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
),
# https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65535,
"Gemini 2.5 Flash Lite Preview 06.17",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
"open_router",
1048576,
8192,
"Gemini 2.0 Flash Lite 001",
"OpenRouter",
"Google",
1,
),
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
16000,
"Sonar Deep Research",
"OpenRouter",
"Perplexity",
3,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router",
131000,
4096,
"Hermes 3 Llama 3.1 405B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router",
12288,
12288,
"Hermes 3 Llama 3.1 70B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
LlmModel.KIMI_K2: ModelMetadata(
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
"open_router",
262144,
262144,
"Qwen 3 235B A22B Thinking 2507",
"OpenRouter",
"Qwen",
1,
),
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Scout 17B 16E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Maverick 17B 128E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
}
# Default model constant for backward compatibility
# Uses the dynamic registry to get the default model
DEFAULT_LLM_MODEL = LlmModel.default()
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ModelUnavailableError(ValueError):
"""Raised when a requested LLM model cannot be resolved for use."""
pass
@dataclass
class ResolvedModel:
"""Result of resolving a model for an LLM call."""
slug: str # The actual model slug to use (may differ from requested if fallback)
provider: str
context_window: int
max_output_tokens: int
used_fallback: bool = False
original_slug: str | None = None # Set if fallback was used
async def resolve_model_for_call(llm_model: LlmModel) -> ResolvedModel:
"""
Resolve a model for use in an LLM call.
Handles:
- Checking if the model exists in the registry
- Falling back to an enabled model from the same provider if disabled
- Refreshing the registry cache if model not found (with DB access)
Args:
llm_model: The requested LlmModel
Returns:
ResolvedModel with all necessary metadata for the call
Raises:
ModelUnavailableError: If model cannot be resolved (not found, disabled with no fallback)
"""
from backend.data.llm_registry import (
get_fallback_model_for_disabled,
get_model_info,
)
model_info = get_model_info(llm_model.value)
# Case 1: Model found and disabled - try fallback
if model_info and not model_info.is_enabled:
fallback = get_fallback_model_for_disabled(llm_model.value)
if fallback:
logger.warning(
f"Model '{llm_model.value}' is disabled. Using fallback "
f"'{fallback.slug}' from same provider ({fallback.metadata.provider})."
)
return ResolvedModel(
slug=fallback.slug,
provider=fallback.metadata.provider,
context_window=fallback.metadata.context_window,
max_output_tokens=fallback.metadata.max_output_tokens or 2**15,
used_fallback=True,
original_slug=llm_model.value,
)
raise ModelUnavailableError(
f"Model '{llm_model.value}' is disabled and no fallback from the same "
f"provider is available. Enable the model or select a different one."
)
# Case 2: Model found and enabled - use it directly
if model_info:
return ResolvedModel(
slug=llm_model.value,
provider=model_info.metadata.provider,
context_window=model_info.metadata.context_window,
max_output_tokens=model_info.metadata.max_output_tokens or 2**15,
)
# Case 3: Model not in registry - try refresh if DB available
logger.warning(f"Model '{llm_model.value}' not found in registry cache")
from backend.data.db import is_connected
if not is_connected():
raise ModelUnavailableError(
f"Model '{llm_model.value}' not found in registry. "
f"The registry may need to be refreshed via the admin UI."
)
# Try refreshing the registry
try:
logger.info(f"Refreshing LLM registry for model '{llm_model.value}'")
await llm_registry.refresh_llm_registry()
except Exception as e:
raise ModelUnavailableError(
f"Model '{llm_model.value}' not found and registry refresh failed: {e}"
) from e
# Check again after refresh
model_info = get_model_info(llm_model.value)
if not model_info:
raise ModelUnavailableError(
f"Model '{llm_model.value}' not found in registry. "
f"Add it via the admin UI at /admin/llms."
)
if not model_info.is_enabled:
raise ModelUnavailableError(
f"Model '{llm_model.value}' exists but is disabled. "
f"Enable it via the admin UI at /admin/llms."
)
logger.info(f"Model '{llm_model.value}' loaded after registry refresh")
return ResolvedModel(
slug=llm_model.value,
provider=model_info.metadata.provider,
context_window=model_info.metadata.context_window,
max_output_tokens=model_info.metadata.max_output_tokens or 2**15,
)
class ToolCall(BaseModel):
@@ -531,12 +382,12 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | anthropic.Omit:
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.omit
return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
@@ -598,7 +449,12 @@ def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
) -> bool | openai.Omit:
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
# Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
# parallel tool calls. Handle both bare slugs ("o1-mini") and provider-prefixed
# slugs ("openai/o1-mini"). The pattern matches "o" followed by a digit at the
# start of the string or after a "/" separator.
is_o_series = re.search(r"(^|/)o\d", llm_model) is not None
if is_o_series or parallel_tool_calls is None:
return openai.omit
return parallel_tool_calls
@@ -634,15 +490,22 @@ async def llm_call(
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
context_window = llm_model.context_window
# Resolve the model - handles disabled models, fallbacks, and cache misses
resolved = await resolve_model_for_call(llm_model)
model_to_use = resolved.slug
provider = resolved.provider
context_window = resolved.context_window
model_max_output = resolved.max_output_tokens
# Create effective model for model-specific parameter resolution (e.g., o-series check)
effective_model = LlmModel(model_to_use)
if compress_prompt_to_fit:
result = await compress_context(
messages=prompt,
target_tokens=llm_model.context_window // 2,
target_tokens=context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
)
if result.error:
logger.warning(
@@ -653,7 +516,7 @@ async def llm_call(
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
model_max_output = llm_model.max_output_tokens or int(2**15)
# model_max_output already set above
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -664,14 +527,14 @@ async def llm_call(
response_format = None
parallel_tool_calls = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
effective_model, parallel_tool_calls
)
if force_json_output:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
@@ -718,7 +581,7 @@ async def llm_call(
)
try:
resp = await client.messages.create(
model=llm_model.value,
model=model_to_use,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
@@ -782,7 +645,7 @@ async def llm_call(
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response = await client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -804,7 +667,7 @@ async def llm_call(
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = await client.generate(
model=llm_model.value,
model=model_to_use,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
@@ -826,7 +689,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
effective_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -834,7 +697,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -868,7 +731,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
effective_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -876,7 +739,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -903,7 +766,7 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "aiml_api":
client = openai.OpenAI(
client = openai.AsyncOpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
@@ -913,8 +776,8 @@ async def llm_call(
},
)
completion = client.chat.completions.create(
model=llm_model.value,
completion = await client.chat.completions.create(
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
@@ -942,11 +805,11 @@ async def llm_call(
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
effective_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -997,9 +860,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
@@ -1062,7 +926,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1428,9 +1292,10 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
sys_prompt: str = SchemaField(
@@ -1524,8 +1389,9 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for summarizing the text.",
json_schema_extra=llm_model_schema_extra(),
)
focus: str = SchemaField(
title="Focus",
@@ -1741,8 +1607,9 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for the conversation.",
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_tokens: int | None = SchemaField(
@@ -1779,7 +1646,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1842,9 +1709,10 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default_factory=LlmModel.default,
description="The language model to use for generating the list.",
advanced=True,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_retries: int = SchemaField(
@@ -1899,7 +1767,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": DEFAULT_LLM_MODEL,
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

View File

@@ -1,300 +0,0 @@
"""
MCP (Model Context Protocol) Tool Block.
A single dynamic block that can connect to any MCP server, discover available tools,
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
from pydantic import SecretStr
from backend.blocks._base import (
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
BlockType,
)
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import BlockInput, BlockOutput
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
TEST_CREDENTIALS = OAuth2Credentials(
id="test-mcp-cred",
provider="mcp",
access_token=SecretStr("mock-mcp-token"),
refresh_token=SecretStr("mock-refresh"),
scopes=[],
title="Mock MCP credential",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
class MCPToolBlock(Block):
"""
A block that connects to an MCP server, lets the user pick a tool,
and executes it with dynamic input/output schema.
The flow:
1. User provides an MCP server URL (and optional credentials)
2. Frontend calls the backend to get tool list from that URL
3. User selects a tool from a dropdown (available_tools)
4. The block's input schema updates to reflect the selected tool's parameters
5. On execution, the block calls the MCP server to run the tool
"""
class Input(BlockSchemaInput):
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
placeholder="https://mcp.example.com/mcp",
)
credentials: MCPCredentials = CredentialsField(
discriminator="server_url",
description="MCP server OAuth credentials",
default={},
)
selected_tool: str = SchemaField(
description="The MCP tool to execute",
placeholder="Select a tool",
default="",
)
tool_input_schema: dict[str, Any] = SchemaField(
description="JSON Schema for the selected tool's input parameters. "
"Populated automatically when a tool is selected.",
default={},
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "
"The fields here are defined by the tool's input schema.",
default={},
)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
"""Return the tool's input schema so the builder UI renders dynamic fields."""
return data.get("tool_input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
"""Return the current tool_arguments as defaults for the dynamic fields."""
return data.get("tool_arguments", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
"""Check which required tool arguments are missing."""
required_fields = cls.get_input_schema(data).get("required", [])
tool_arguments = data.get("tool_arguments", {})
return set(required_fields) - set(tool_arguments)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
"""Validate tool_arguments against the tool's input schema."""
tool_schema = cls.get_input_schema(data)
if not tool_schema:
return None
tool_arguments = data.get("tool_arguments", {})
return validate_with_jsonschema(tool_schema, tool_arguments)
class Output(BlockSchemaOutput):
result: Any = SchemaField(description="The result returned by the MCP tool")
error: str = SchemaField(description="Error message if the tool call failed")
def __init__(self):
super().__init__(
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
description="Connect to any MCP server and execute its tools. "
"Provide a server URL, select a tool, and pass arguments dynamically.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MCPToolBlock.Input,
output_schema=MCPToolBlock.Output,
block_type=BlockType.MCP_TOOL,
test_credentials=TEST_CREDENTIALS,
test_input={
"server_url": "https://mcp.example.com/mcp",
"credentials": TEST_CREDENTIALS_INPUT,
"selected_tool": "get_weather",
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
},
test_output=[
(
"result",
{"weather": "sunny", "temperature": 20},
),
],
test_mock={
"_call_mcp_tool": lambda *a, **kw: {
"weather": "sunny",
"temperature": 20,
},
},
)
async def _call_mcp_tool(
self,
server_url: str,
tool_name: str,
arguments: dict[str, Any],
auth_token: str | None = None,
) -> Any:
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
client = MCPClient(server_url, auth_token=auth_token)
await client.initialize()
result = await client.call_tool(tool_name, arguments)
if result.is_error:
error_text = ""
for item in result.content:
if item.get("type") == "text":
error_text += item.get("text", "")
raise MCPClientError(
f"MCP tool '{tool_name}' returned an error: "
f"{error_text or 'Unknown error'}"
)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
async def _auto_lookup_credential(
user_id: str, server_url: str
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
"""
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None
async def run(
self,
input_data: Input,
*,
user_id: str,
credentials: OAuth2Credentials | None = None,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
yield "error", "MCP server URL is required"
return
if not input_data.selected_tool:
yield "error", "No tool selected. Please select a tool from the dropdown."
return
# Validate required tool arguments before calling the server.
# The executor-level validation is bypassed for MCP blocks because
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
# from the validation context.
required = set(input_data.tool_input_schema.get("required", []))
if required:
missing = required - set(input_data.tool_arguments.keys())
if missing:
yield "error", (
f"Missing required argument(s): {', '.join(sorted(missing))}. "
f"Please fill in all required fields marked with * in the block form."
)
return
# If no credentials were injected by the executor (e.g. legacy nodes
# that don't have the credentials field set), try to auto-lookup
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, input_data.server_url
)
auth_token = (
credentials.access_token.get_secret_value() if credentials else None
)
try:
result = await self._call_mcp_tool(
server_url=input_data.server_url,
tool_name=input_data.selected_tool,
arguments=input_data.tool_arguments,
auth_token=auth_token,
)
yield "result", result
except MCPClientError as e:
yield "error", str(e)
except Exception as e:
logger.exception(f"MCP tool call failed: {e}")
yield "error", f"MCP tool call failed: {str(e)}"

View File

@@ -1,323 +0,0 @@
"""
MCP (Model Context Protocol) HTTP client.
Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from backend.util.request import Requests
logger = logging.getLogger(__name__)
@dataclass
class MCPTool:
"""Represents an MCP tool discovered from a server."""
name: str
description: str
input_schema: dict[str, Any]
@dataclass
class MCPCallResult:
"""Result from calling an MCP tool."""
content: list[dict[str, Any]] = field(default_factory=list)
is_error: bool = False
class MCPClientError(Exception):
"""Raised when an MCP protocol error occurs."""
pass
class MCPClient:
"""
Async HTTP client for the MCP Streamable HTTP transport.
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
Supports optional Bearer token authentication.
"""
def __init__(
self,
server_url: str,
auth_token: str | None = None,
):
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None
def _next_id(self) -> int:
self._request_id += 1
return self._request_id
def _build_headers(self) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
if self._session_id:
headers["Mcp-Session-Id"] = self._session_id
return headers
def _build_jsonrpc_request(
self, method: str, params: dict[str, Any] | None = None
) -> dict[str, Any]:
req: dict[str, Any] = {
"jsonrpc": "2.0",
"method": method,
"id": self._next_id(),
}
if params is not None:
req["params"] = params
return req
@staticmethod
def _parse_sse_response(text: str) -> dict[str, Any]:
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
MCP servers may return responses as SSE with format:
event: message
data: {"jsonrpc":"2.0","result":{...},"id":1}
We extract the last `data:` line that contains a JSON-RPC response
(i.e. has an "id" field), which is the reply to our request.
"""
last_data: dict[str, Any] | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("data:"):
payload = stripped[len("data:") :].strip()
if not payload:
continue
try:
parsed = json.loads(payload)
# Only keep JSON-RPC responses (have "id"), skip notifications
if isinstance(parsed, dict) and "id" in parsed:
last_data = parsed
except (json.JSONDecodeError, ValueError):
continue
if last_data is None:
raise MCPClientError("No JSON-RPC response found in SSE stream")
return last_data
async def _send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send a JSON-RPC request to the MCP server and return the result.
Handles both ``application/json`` and ``text/event-stream`` responses
as required by the MCP Streamable HTTP transport specification.
"""
payload = self._build_jsonrpc_request(method, params)
headers = self._build_headers()
requests = Requests(
raise_for_status=True,
extra_headers=headers,
)
response = await requests.post(self.server_url, json=payload)
# Capture session ID from response (MCP Streamable HTTP transport)
session_id = response.headers.get("Mcp-Session-Id")
if session_id:
self._session_id = session_id
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
try:
body = response.json()
except Exception as e:
raise MCPClientError(
f"MCP server returned non-JSON response: {e}"
) from e
if not isinstance(body, dict):
raise MCPClientError(
f"MCP server returned unexpected JSON type: {type(body).__name__}"
)
# Handle JSON-RPC error
if "error" in body:
error = body["error"]
if isinstance(error, dict):
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
raise MCPClientError(f"MCP server error: {error}")
return body.get("result")
async def _send_notification(self, method: str) -> None:
"""Send a JSON-RPC notification (no id, no response expected)."""
headers = self._build_headers()
notification = {"jsonrpc": "2.0", "method": method}
requests = Requests(
raise_for_status=False,
extra_headers=headers,
)
await requests.post(self.server_url, json=notification)
async def discover_auth(self) -> dict[str, Any] | None:
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
Returns ``None`` if the server doesn't require auth, otherwise returns
a dict with:
- ``authorization_servers``: list of authorization server URLs
- ``resource``: the resource indicator URL (usually the MCP endpoint)
- ``scopes_supported``: optional list of supported scopes
The caller can then fetch the authorization server metadata to get
``authorization_endpoint``, ``token_endpoint``, etc.
"""
from urllib.parse import urlparse
parsed = urlparse(self.server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
# Build candidates for protected-resource metadata (per RFC 9728)
path = parsed.path.rstrip("/")
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
candidates.append(f"{base}/.well-known/oauth-protected-resource")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_servers" in data:
return data
except Exception:
continue
return None
async def discover_auth_server_metadata(
self, auth_server_url: str
) -> dict[str, Any] | None:
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
Given an authorization server URL, returns a dict with:
- ``authorization_endpoint``
- ``token_endpoint``
- ``registration_endpoint`` (for dynamic client registration)
- ``scopes_supported``
- ``code_challenge_methods_supported``
- etc.
"""
from urllib.parse import urlparse
parsed = urlparse(auth_server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
candidates.append(f"{base}/.well-known/oauth-authorization-server")
candidates.append(f"{base}/.well-known/openid-configuration")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_endpoint" in data:
return data
except Exception:
continue
return None
async def initialize(self) -> dict[str, Any]:
"""
Send the MCP initialize request.
This is required by the MCP protocol before any other requests.
Returns the server's capabilities.
"""
result = await self._send_request(
"initialize",
{
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
},
)
# Send initialized notification (no response expected)
await self._send_notification("notifications/initialized")
return result or {}
async def list_tools(self) -> list[MCPTool]:
"""
Discover available tools from the MCP server.
Returns a list of MCPTool objects with name, description, and input schema.
"""
result = await self._send_request("tools/list")
if not result or "tools" not in result:
return []
tools = []
for tool_data in result["tools"]:
tools.append(
MCPTool(
name=tool_data.get("name", ""),
description=tool_data.get("description", ""),
input_schema=tool_data.get("inputSchema", {}),
)
)
return tools
async def call_tool(
self, tool_name: str, arguments: dict[str, Any]
) -> MCPCallResult:
"""
Call a tool on the MCP server.
Args:
tool_name: The name of the tool to call.
arguments: The arguments to pass to the tool.
Returns:
MCPCallResult with the tool's response content.
"""
result = await self._send_request(
"tools/call",
{"name": tool_name, "arguments": arguments},
)
if not result:
return MCPCallResult(is_error=True)
return MCPCallResult(
content=result.get("content", []),
is_error=result.get("isError", False),
)

View File

@@ -1,204 +0,0 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""
import logging
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class MCPOAuthHandler(BaseOAuthHandler):
"""
OAuth handler for MCP servers with dynamically-discovered endpoints.
Construction requires the authorization and token endpoint URLs,
which are obtained via MCP OAuth metadata discovery
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
"""
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
DEFAULT_SCOPES: ClassVar[list[str]] = []
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
*,
authorize_url: str,
token_url: str,
revoke_url: str | None = None,
resource_url: str | None = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.authorize_url = authorize_url
self.token_url = token_url
self.revoke_url = revoke_url
self.resource_url = resource_url
def get_login_url(
self,
scopes: list[str],
state: str,
code_challenge: Optional[str],
) -> str:
scopes = self.handle_default_scopes(scopes)
params: dict[str, str] = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": state,
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE (S256) — included when the caller provides a code_challenge
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
# MCP spec requires resource indicator (RFC 8707)
if self.resource_url:
params["resource"] = self.resource_url
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
scopes: list[str],
code_verifier: Optional[str],
) -> OAuth2Credentials:
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if code_verifier:
data["code_verifier"] = code_verifier
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth token response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else None
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=None,
scopes=scopes,
metadata={
"mcp_token_url": self.token_url,
"mcp_resource_url": self.resource_url,
},
)
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError("No refresh token available for MCP OAuth credentials")
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth refresh response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else credentials.refresh_token
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=credentials.refresh_token_expires_at,
scopes=credentials.scopes,
metadata=credentials.metadata,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not self.revoke_url:
return False
try:
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
"client_id": self.client_id,
}
await Requests().post(
self.revoke_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
return True
except Exception:
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
return False

View File

@@ -1,109 +0,0 @@
"""
End-to-end tests against a real public MCP server.
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
independently of the rest of the test suite (they require network access).
"""
import json
import os
import pytest
from backend.blocks.mcp.client import MCPClient
# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
# Skip all tests in this module unless RUN_E2E env var is set
pytestmark = pytest.mark.skipif(
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
)
class TestRealMCPServer:
"""Tests against the live OpenAI docs MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
"""Verify we can complete the MCP handshake with a real server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert "serverInfo" in result
assert result["serverInfo"]["name"] == "openai-docs-mcp"
assert "tools" in result.get("capabilities", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
"""Verify we can discover tools from a real MCP server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
tools = await client.list_tools()
assert len(tools) >= 3 # server has at least 5 tools as of writing
tool_names = {t.name for t in tools}
# These tools are documented and should be stable
assert "search_openai_docs" in tool_names
assert "list_openai_docs" in tool_names
assert "fetch_openai_doc" in tool_names
# Verify schema structure
search_tool = next(t for t in tools if t.name == "search_openai_docs")
assert "query" in search_tool.input_schema.get("properties", {})
assert "query" in search_tool.input_schema.get("required", [])
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_list_api_endpoints(self):
"""Call the list_api_endpoints tool and verify we get real data."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool("list_api_endpoints", {})
assert not result.is_error
assert len(result.content) >= 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert "paths" in data or "urls" in data
# The OpenAI API should have many endpoints
total = data.get("total", len(data.get("paths", [])))
assert total > 50
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_search(self):
"""Search for docs and verify we get results."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool(
"search_openai_docs", {"query": "chat completions", "limit": 3}
)
assert not result.is_error
assert len(result.content) >= 1
@pytest.mark.asyncio(loop_scope="session")
async def test_sse_response_handling(self):
"""Verify the client correctly handles SSE responses from a real server.
This is the key test — our local test server returns JSON,
but real MCP servers typically return SSE. This proves the
SSE parsing works end-to-end.
"""
client = MCPClient(OPENAI_DOCS_MCP_URL)
# initialize() internally calls _send_request which must parse SSE
result = await client.initialize()
# If we got here without error, SSE parsing works
assert isinstance(result, dict)
assert "protocolVersion" in result
# Also verify list_tools works (another SSE response)
tools = await client.list_tools()
assert len(tools) > 0
assert all(hasattr(t, "name") for t in tools)

View File

@@ -1,389 +0,0 @@
"""
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
These tests spin up a local MCP test server and run the full client/block flow
against it — no mocking, real HTTP requests.
"""
import asyncio
import json
import threading
from unittest.mock import patch
import pytest
from aiohttp import web
from pydantic import SecretStr
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.test_server import create_test_mcp_app
from backend.data.model import OAuth2Credentials
MOCK_USER_ID = "test-user-integration"
class _MCPTestServer:
"""
Run an MCP test server in a background thread with its own event loop.
This avoids event loop conflicts with pytest-asyncio.
"""
def __init__(self, auth_token: str | None = None):
self.auth_token = auth_token
self.url: str = ""
self._runner: web.AppRunner | None = None
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._started = threading.Event()
def _run(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_until_complete(self._start())
self._started.set()
self._loop.run_forever()
async def _start(self):
app = create_test_mcp_app(auth_token=self.auth_token)
self._runner = web.AppRunner(app)
await self._runner.setup()
site = web.TCPSite(self._runner, "127.0.0.1", 0)
await site.start()
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
self.url = f"http://127.0.0.1:{port}/mcp"
def start(self):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
if not self._started.wait(timeout=5):
raise RuntimeError("MCP test server failed to start within 5 seconds")
return self
def stop(self):
if self._loop and self._runner:
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
timeout=5
)
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=5)
@pytest.fixture(scope="module")
def mcp_server():
"""Start a local MCP test server in a background thread."""
server = _MCPTestServer()
server.start()
yield server.url
server.stop()
@pytest.fixture(scope="module")
def mcp_server_with_auth():
"""Start a local MCP test server with auth in a background thread."""
server = _MCPTestServer(auth_token="test-secret-token")
server.start()
yield server.url, "test-secret-token"
server.stop()
@pytest.fixture(autouse=True)
def _allow_localhost():
"""
Allow 127.0.0.1 through SSRF protection for integration tests.
The Requests class blocks private IPs by default. We patch the Requests
constructor to always include 127.0.0.1 as a trusted origin so the local
test server is reachable.
"""
from backend.util.request import Requests
original_init = Requests.__init__
def patched_init(self, *args, **kwargs):
trusted = list(kwargs.get("trusted_origins") or [])
trusted.append("http://127.0.0.1")
kwargs["trusted_origins"] = trusted
original_init(self, *args, **kwargs)
with patch.object(Requests, "__init__", patched_init):
yield
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
"""Create an MCPClient for integration tests."""
return MCPClient(url, auth_token=auth_token)
# ── MCPClient integration tests ──────────────────────────────────────
class TestMCPClientIntegration:
"""Test MCPClient against a real local MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self, mcp_server):
client = _make_client(mcp_server)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert result["serverInfo"]["name"] == "test-mcp-server"
assert "tools" in result["capabilities"]
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
tool_names = {t.name for t in tools}
assert tool_names == {"get_weather", "add_numbers", "echo"}
# Check get_weather schema
weather = next(t for t in tools if t.name == "get_weather")
assert weather.description == "Get current weather for a city"
assert "city" in weather.input_schema["properties"]
assert weather.input_schema["required"] == ["city"]
# Check add_numbers schema
add = next(t for t in tools if t.name == "add_numbers")
assert "a" in add.input_schema["properties"]
assert "b" in add.input_schema["properties"]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_get_weather(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert data["city"] == "London"
assert data["temperature"] == 22
assert data["condition"] == "sunny"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_add_numbers(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
assert not result.is_error
data = json.loads(result.content[0]["text"])
assert data["result"] == 10
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_echo(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("echo", {"message": "Hello MCP!"})
assert not result.is_error
assert result.content[0]["text"] == "Hello MCP!"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_unknown_tool(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("nonexistent_tool", {})
assert result.is_error
assert "Unknown tool" in result.content[0]["text"]
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_success(self, mcp_server_with_auth):
url, token = mcp_server_with_auth
client = _make_client(url, auth_token=token)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
tools = await client.list_tools()
assert len(tools) == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_failure(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url, auth_token="wrong-token")
with pytest.raises(Exception):
await client.initialize()
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_missing(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url)
with pytest.raises(Exception):
await client.initialize()
# ── MCPToolBlock integration tests ───────────────────────────────────
class TestMCPToolBlockIntegration:
"""Test MCPToolBlock end-to-end against a real local MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_get_weather(self, mcp_server):
"""Full flow: discover tools, select one, execute it."""
# Step 1: Discover tools (simulating what the frontend/API would do)
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
# Step 2: User selects "get_weather" and we get its schema
weather_tool = next(t for t in tools if t.name == "get_weather")
# Step 3: Execute the block — no credentials (public server)
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="get_weather",
tool_input_schema=weather_tool.input_schema,
tool_arguments={"city": "Paris"},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
result = outputs[0][1]
assert result["city"] == "Paris"
assert result["temperature"] == 22
assert result["condition"] == "sunny"
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_add_numbers(self, mcp_server):
"""Full flow for add_numbers tool."""
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
add_tool = next(t for t in tools if t.name == "add_numbers")
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="add_numbers",
tool_input_schema=add_tool.input_schema,
tool_arguments={"a": 42, "b": 58},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1]["result"] == 100
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_echo_plain_text(self, mcp_server):
"""Verify plain text (non-JSON) responses work."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Hello from AutoGPT!"},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Hello from AutoGPT!"
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
"""Calling an unknown tool should yield an error output."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="nonexistent_tool",
tool_arguments={},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "error"
assert "returned an error" in outputs[0][1]
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_with_auth(self, mcp_server_with_auth):
"""Full flow with authentication via credentials kwarg."""
url, token = mcp_server_with_auth
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=url,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Authenticated!"},
)
# Pass credentials via the standard kwarg (as the executor would)
test_creds = OAuth2Credentials(
id="test-cred",
provider="mcp",
access_token=SecretStr(token),
refresh_token=SecretStr(""),
scopes=[],
title="Test MCP credential",
)
outputs = []
async for name, data in block.run(
input_data, user_id=MOCK_USER_ID, credentials=test_creds
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Authenticated!"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_runs_without_auth(self, mcp_server):
"""Block runs without auth when no credentials are provided."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "No auth needed"},
)
outputs = []
async for name, data in block.run(
input_data, user_id=MOCK_USER_ID, credentials=None
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "No auth needed"

View File

@@ -1,619 +0,0 @@
"""
Tests for MCP client and MCPToolBlock.
"""
import json
from unittest.mock import AsyncMock, patch
import pytest
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.util.test import execute_block_test
# ── SSE parsing unit tests ───────────────────────────────────────────
class TestSSEParsing:
"""Tests for SSE (text/event-stream) response parsing."""
def test_parse_sse_simple(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"tools": []}
assert body["id"] == 1
def test_parse_sse_with_notifications(self):
"""SSE streams can contain notifications (no id) before the response."""
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
"\n"
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"ok": True}
assert body["id"] == 2
def test_parse_sse_error_response(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
)
body = MCPClient._parse_sse_response(sse)
assert "error" in body
assert body["error"]["code"] == -32600
def test_parse_sse_no_data_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("event: message\n\n")
def test_parse_sse_empty_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("")
def test_parse_sse_ignores_non_data_lines(self):
sse = (
": comment line\n"
"event: message\n"
"id: 123\n"
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "ok"
def test_parse_sse_uses_last_response(self):
"""If multiple responses exist, use the last one."""
sse = (
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
"\n"
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "second"
# ── MCPClient unit tests ─────────────────────────────────────────────
class TestMCPClient:
"""Tests for the MCP HTTP client."""
def test_build_headers_without_auth(self):
client = MCPClient("https://mcp.example.com")
headers = client._build_headers()
assert "Authorization" not in headers
assert headers["Content-Type"] == "application/json"
def test_build_headers_with_auth(self):
client = MCPClient("https://mcp.example.com", auth_token="my-token")
headers = client._build_headers()
assert headers["Authorization"] == "Bearer my-token"
def test_build_jsonrpc_request(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request("tools/list")
assert req["jsonrpc"] == "2.0"
assert req["method"] == "tools/list"
assert "id" in req
assert "params" not in req
def test_build_jsonrpc_request_with_params(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request(
"tools/call", {"name": "test", "arguments": {"x": 1}}
)
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
def test_request_id_increments(self):
client = MCPClient("https://mcp.example.com")
req1 = client._build_jsonrpc_request("tools/list")
req2 = client._build_jsonrpc_request("tools/list")
assert req2["id"] > req1["id"]
def test_server_url_trailing_slash_stripped(self):
client = MCPClient("https://mcp.example.com/mcp/")
assert client.server_url == "https://mcp.example.com/mcp"
@pytest.mark.asyncio(loop_scope="session")
async def test_send_request_success(self):
client = MCPClient("https://mcp.example.com")
mock_response = AsyncMock()
mock_response.json.return_value = {
"jsonrpc": "2.0",
"result": {"tools": []},
"id": 1,
}
with patch.object(client, "_send_request", return_value={"tools": []}):
result = await client._send_request("tools/list")
assert result == {"tools": []}
@pytest.mark.asyncio(loop_scope="session")
async def test_send_request_error(self):
client = MCPClient("https://mcp.example.com")
async def mock_send(*args, **kwargs):
raise MCPClientError("MCP server error [-32600]: Invalid Request")
with patch.object(client, "_send_request", side_effect=mock_send):
with pytest.raises(MCPClientError, match="Invalid Request"):
await client._send_request("tools/list")
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"tools": [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
{
"name": "search",
"description": "Search the web",
"inputSchema": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
]
}
with patch.object(client, "_send_request", return_value=mock_result):
tools = await client.list_tools()
assert len(tools) == 2
assert tools[0].name == "get_weather"
assert tools[0].description == "Get current weather for a city"
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
assert tools[1].name == "search"
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools_empty(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value={"tools": []}):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_success(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
],
"isError": False,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_error(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [{"type": "text", "text": "City not found"}],
"isError": True,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "???"})
assert result.is_error
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
result = await client.call_tool("get_weather", {"city": "London"})
assert result.is_error
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {}},
"serverInfo": {"name": "test-server", "version": "1.0.0"},
}
with (
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
patch.object(client, "_send_notification") as mock_notif,
):
result = await client.initialize()
mock_req.assert_called_once()
mock_notif.assert_called_once_with("notifications/initialized")
assert result["protocolVersion"] == "2025-03-26"
# ── MCPToolBlock unit tests ──────────────────────────────────────────
MOCK_USER_ID = "test-user-123"
class TestMCPToolBlock:
"""Tests for the MCPToolBlock."""
def test_block_instantiation(self):
block = MCPToolBlock()
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
assert block.name == "MCPToolBlock"
def test_input_schema_has_required_fields(self):
block = MCPToolBlock()
schema = block.input_schema.jsonschema()
props = schema.get("properties", {})
assert "server_url" in props
assert "selected_tool" in props
assert "tool_arguments" in props
assert "credentials" in props
def test_output_schema(self):
block = MCPToolBlock()
schema = block.output_schema.jsonschema()
props = schema.get("properties", {})
assert "result" in props
assert "error" in props
def test_get_input_schema_with_tool_schema(self):
tool_schema = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
data = {"tool_input_schema": tool_schema}
result = MCPToolBlock.Input.get_input_schema(data)
assert result == tool_schema
def test_get_input_schema_without_tool_schema(self):
result = MCPToolBlock.Input.get_input_schema({})
assert result == {}
def test_get_input_defaults(self):
data = {"tool_arguments": {"city": "London"}}
result = MCPToolBlock.Input.get_input_defaults(data)
assert result == {"city": "London"}
def test_get_missing_input(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {
"city": {"type": "string"},
"units": {"type": "string"},
},
"required": ["city", "units"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == {"units"}
def test_get_missing_input_all_present(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == set()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_with_mock(self):
"""Test the block using the built-in test infrastructure."""
block = MCPToolBlock()
await execute_block_test(block)
@pytest.mark.asyncio(loop_scope="session")
async def test_run_missing_server_url(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="",
selected_tool="test",
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [("error", "MCP server URL is required")]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_missing_tool(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="",
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [
("error", "No tool selected. Please select a tool from the dropdown.")
]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_success(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="get_weather",
tool_input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
},
tool_arguments={"city": "London"},
)
async def mock_call(*args, **kwargs):
return {"temp": 20, "city": "London"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == {"temp": 20, "city": "London"}
@pytest.mark.asyncio(loop_scope="session")
async def test_run_mcp_error(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="bad_tool",
)
async def mock_call(*args, **kwargs):
raise MCPClientError("Tool not found")
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs[0][0] == "error"
assert "Tool not found" in outputs[0][1]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_parses_json_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": '{"temp": 20}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {"temp": 20}
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_plain_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Hello, world!"},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == "Hello, world!"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_multiple_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Part 1"},
{"type": "text", "text": '{"part": 2}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == ["Part 1", {"part": 2}]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_error_result(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[{"type": "text", "text": "Something went wrong"}],
is_error=True,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
with pytest.raises(MCPClientError, match="returned an error"):
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_image_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
@pytest.mark.asyncio(loop_scope="session")
async def test_run_with_credentials(self):
"""Verify the block uses OAuth2Credentials and passes auth token."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens: list[str | None] = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
test_creds = OAuth2Credentials(
id="cred-123",
provider="mcp",
access_token=SecretStr("resolved-token"),
refresh_token=SecretStr(""),
scopes=[],
title="Test MCP credential",
)
async for _ in block.run(
input_data, user_id=MOCK_USER_ID, credentials=test_creds
):
pass
assert captured_tokens == ["resolved-token"]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_without_credentials(self):
"""Verify the block works without credentials (public server)."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens: list[str | None] = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert captured_tokens == [None]
assert outputs == [("result", "ok")]

View File

@@ -1,242 +0,0 @@
"""
Tests for MCP OAuth handler.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
resp = MagicMock()
resp.status = status
resp.ok = 200 <= status < 300
resp.json.return_value = json_data
return resp
class TestMCPOAuthHandler:
"""Tests for the MCPOAuthHandler."""
def _make_handler(self, **overrides) -> MCPOAuthHandler:
defaults = {
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"redirect_uri": "https://app.example.com/callback",
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
}
defaults.update(overrides)
return MCPOAuthHandler(**defaults)
def test_get_login_url_basic(self):
handler = self._make_handler()
url = handler.get_login_url(
scopes=["read", "write"],
state="random-state-token",
code_challenge="S256-challenge-value",
)
assert "https://auth.example.com/authorize?" in url
assert "response_type=code" in url
assert "client_id=test-client-id" in url
assert "state=random-state-token" in url
assert "code_challenge=S256-challenge-value" in url
assert "code_challenge_method=S256" in url
assert "scope=read+write" in url
def test_get_login_url_with_resource(self):
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
url = handler.get_login_url(
scopes=[], state="state", code_challenge="challenge"
)
assert "resource=https" in url
def test_get_login_url_without_pkce(self):
handler = self._make_handler()
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
assert "code_challenge" not in url
assert "code_challenge_method" not in url
@pytest.mark.asyncio(loop_scope="session")
async def test_exchange_code_for_tokens(self):
handler = self._make_handler()
resp = _mock_response(
{
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"expires_in": 3600,
"token_type": "Bearer",
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
creds = await handler.exchange_code_for_tokens(
code="auth-code",
scopes=["read"],
code_verifier="pkce-verifier",
)
assert isinstance(creds, OAuth2Credentials)
assert creds.access_token.get_secret_value() == "new-access-token"
assert creds.refresh_token is not None
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
assert creds.scopes == ["read"]
assert creds.access_token_expires_at is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_tokens(self):
handler = self._make_handler()
existing_creds = OAuth2Credentials(
id="existing-id",
provider="mcp",
access_token=SecretStr("old-token"),
refresh_token=SecretStr("old-refresh"),
scopes=["read"],
title="test",
)
resp = _mock_response(
{
"access_token": "refreshed-token",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
refreshed = await handler._refresh_tokens(existing_creds)
assert refreshed.id == "existing-id"
assert refreshed.access_token.get_secret_value() == "refreshed-token"
assert refreshed.refresh_token is not None
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_tokens_no_refresh_token(self):
handler = self._make_handler()
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=["read"],
title="test",
)
with pytest.raises(ValueError, match="No refresh token"):
await handler._refresh_tokens(creds)
@pytest.mark.asyncio(loop_scope="session")
async def test_revoke_tokens_no_url(self):
handler = self._make_handler(revoke_url=None)
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
result = await handler.revoke_tokens(creds)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_revoke_tokens_with_url(self):
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
resp = _mock_response({}, status=200)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
result = await handler.revoke_tokens(creds)
assert result is True
class TestMCPClientDiscovery:
"""Tests for MCPClient OAuth metadata discovery."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_found(self):
client = MCPClient("https://mcp.example.com/mcp")
metadata = {
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
resp = _mock_response(metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is not None
assert result["authorization_servers"] == ["https://auth.example.com"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_not_found(self):
client = MCPClient("https://mcp.example.com/mcp")
resp = _mock_response({}, status=404)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_server_metadata(self):
client = MCPClient("https://mcp.example.com/mcp")
server_metadata = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"registration_endpoint": "https://auth.example.com/register",
"code_challenge_methods_supported": ["S256"],
}
resp = _mock_response(server_metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth_server_metadata(
"https://auth.example.com"
)
assert result is not None
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
assert result["token_endpoint"] == "https://auth.example.com/token"

View File

@@ -1,162 +0,0 @@
"""
Minimal MCP server for integration testing.
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""
import json
import logging
from aiohttp import web
logger = logging.getLogger(__name__)
# Sample tools this test server exposes
TEST_TOOLS = [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
},
},
"required": ["city"],
},
},
{
"name": "add_numbers",
"description": "Add two numbers together",
"inputSchema": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
},
},
{
"name": "echo",
"description": "Echo back the input message",
"inputSchema": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message to echo"},
},
"required": ["message"],
},
},
]
def _handle_initialize(params: dict) -> dict:
return {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {"listChanged": False}},
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
}
def _handle_tools_list(params: dict) -> dict:
return {"tools": TEST_TOOLS}
def _handle_tools_call(params: dict) -> dict:
tool_name = params.get("name", "")
arguments = params.get("arguments", {})
if tool_name == "get_weather":
city = arguments.get("city", "Unknown")
return {
"content": [
{
"type": "text",
"text": json.dumps(
{"city": city, "temperature": 22, "condition": "sunny"}
),
}
],
}
elif tool_name == "add_numbers":
a = arguments.get("a", 0)
b = arguments.get("b", 0)
return {
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
}
elif tool_name == "echo":
message = arguments.get("message", "")
return {
"content": [{"type": "text", "text": message}],
}
else:
return {
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
"isError": True,
}
HANDLERS = {
"initialize": _handle_initialize,
"tools/list": _handle_tools_list,
"tools/call": _handle_tools_call,
}
async def handle_mcp_request(request: web.Request) -> web.Response:
"""Handle incoming MCP JSON-RPC 2.0 requests."""
# Check auth if configured
expected_token = request.app.get("auth_token")
if expected_token:
auth_header = request.headers.get("Authorization", "")
if auth_header != f"Bearer {expected_token}":
return web.json_response(
{
"jsonrpc": "2.0",
"error": {"code": -32001, "message": "Unauthorized"},
"id": None,
},
status=401,
)
body = await request.json()
# Handle notifications (no id field) — just acknowledge
if "id" not in body:
return web.Response(status=202)
method = body.get("method", "")
params = body.get("params", {})
request_id = body.get("id")
handler = HANDLERS.get(method)
if not handler:
return web.json_response(
{
"jsonrpc": "2.0",
"error": {
"code": -32601,
"message": f"Method not found: {method}",
},
"id": request_id,
}
)
result = handler(params)
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
"""Create an aiohttp app that acts as an MCP server."""
app = web.Application()
app.router.add_post("/mcp", handle_mcp_request)
if auth_token:
app["auth_token"] = auth_token
return app

View File

@@ -226,9 +226,10 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.DEFAULT_LLM_MODEL,
default_factory=llm.LlmModel.default,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm.llm_model_schema_extra(),
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(

View File

@@ -10,13 +10,13 @@ import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.data import llm_registry
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_metadata = self.metadata
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -102,24 +102,28 @@ class StagehandRecommendedLlmModel(str, Enum):
), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
model_name = f"{model_metadata.provider}/{model_name}"
logger.error(f"Model name: {model_name}")
logger.debug(f"Model name: {model_name}")
return model_name
@property
def provider(self) -> str:
return MODEL_METADATA[LlmModel(self.value)].provider
return self.metadata.provider
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[LlmModel(self.value)]
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
# Fallback to LlmModel enum if registry lookup fails
return LlmModel(self.value).metadata
@property
def context_window(self) -> int:
return MODEL_METADATA[LlmModel(self.value)].context_window
return self.metadata.context_window
@property
def max_output_tokens(self) -> int | None:
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
return self.metadata.max_output_tokens
class StagehandObserveBlock(Block):

View File

@@ -19,6 +19,30 @@ CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a
async def initialize_blocks() -> None:
# Refresh LLM registry before initializing blocks so blocks can use registry data
# This ensures the registry cache is populated even in executor context
try:
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
# Only refresh if we have DB access (check if Prisma is connected)
from backend.data.db import is_connected
if is_connected():
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning(
"Prisma not connected, skipping LLM registry refresh during block initialization"
)
except Exception as exc:
logger.warning(
"Failed to refresh LLM registry during block initialization: %s", exc
)
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.blocks import get_blocks
from backend.sdk.cost_integration import sync_all_provider_costs
from backend.util.retry import func_retry

View File

@@ -1,5 +1,8 @@
import logging
from typing import Type
import prisma.models
from backend.blocks._base import Block, BlockCost, BlockCostType
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
from backend.blocks.ai_image_generator_block import AIImageGeneratorBlock, ImageGenModel
@@ -24,13 +27,11 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
MODEL_METADATA,
AIConversationBlock,
AIListGeneratorBlock,
AIStructuredResponseGeneratorBlock,
AITextGeneratorBlock,
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
@@ -38,6 +39,7 @@ from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data import llm_registry
from backend.integrations.credentials_store import (
aiml_api_credentials,
anthropic_credentials,
@@ -57,210 +59,116 @@ from backend.integrations.credentials_store import (
v0_credentials,
)
# =============== Configure the cost for each LLM Model call =============== #
logger = logging.getLogger(__name__)
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_6_OPUS: 14,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
LlmModel.AMAZON_NOVA_LITE_V1: 1,
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,
LlmModel.QWEN3_CODER: 9,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
PROVIDER_CREDENTIALS = {
"openai": openai_credentials,
"anthropic": anthropic_credentials,
"groq": groq_credentials,
"open_router": open_router_credentials,
"llama_api": llama_api_credentials,
"aiml_api": aiml_api_credentials,
"v0": v0_credentials,
}
for model in LlmModel:
if model not in MODEL_COST:
raise ValueError(f"Missing MODEL_COST for model: {model}")
# =============== Configure the cost for each LLM Model call =============== #
# All LLM costs now come from the database via llm_registry
LLM_COST: list[BlockCost] = []
LLM_COST = (
# Anthropic Models
[
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
async def _build_llm_costs_from_registry() -> list[BlockCost]:
"""
Build BlockCost list from all models in the LLM registry.
This function checks for active model migrations with customCreditCost overrides.
When a model has been migrated with a custom price, that price is used instead
of the target model's default cost.
"""
# Query active migrations with custom pricing overrides.
# Note: LlmModelMigration is system-level data (no userId field) and this function
# is only called during app startup and admin operations, so no user ID filter needed.
migration_overrides: dict[str, int] = {}
try:
active_migrations = await prisma.models.LlmModelMigration.prisma().find_many(
where={
"isReverted": False,
"customCreditCost": {"not": None},
}
)
# Key by targetModelSlug since that's the model nodes are now using
# after migration. The custom cost applies to the target model.
migration_overrides = {
migration.targetModelSlug: migration.customCreditCost
for migration in active_migrations
if migration.customCreditCost is not None
}
if migration_overrides:
logger.info(
"Found %d active model migrations with custom pricing overrides",
len(migration_overrides),
)
except Exception as exc:
logger.warning(
"Failed to query model migration overrides: %s. Proceeding with default costs.",
exc,
exc_info=True,
)
costs: list[BlockCost] = []
for model in llm_registry.iter_dynamic_models():
for cost in model.costs:
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
if not credentials:
logger.warning(
"Skipping cost entry for %s due to unknown credentials provider %s",
model.slug,
cost.credential_provider,
)
continue
# Check if this model has a custom cost override from migration
cost_amount = migration_overrides.get(model.slug, cost.credit_cost)
if model.slug in migration_overrides:
logger.debug(
"Applying custom cost override for model %s: %d credits (default: %d)",
model.slug,
cost_amount,
cost.credit_cost,
)
cost_filter = {
"model": model.slug,
"credentials": {
"id": anthropic_credentials.id,
"provider": anthropic_credentials.provider,
"type": anthropic_credentials.type,
"id": credentials.id,
"provider": credentials.provider,
"type": credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "anthropic"
]
# OpenAI Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": openai_credentials.id,
"provider": openai_credentials.provider,
"type": openai_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "openai"
]
# Groq Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {"id": groq_credentials.id},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "groq"
]
# Open Router Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": open_router_credentials.id,
"provider": open_router_credentials.provider,
"type": open_router_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "open_router"
]
# Llama API Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": llama_api_credentials.id,
"provider": llama_api_credentials.provider,
"type": llama_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": aiml_api_credentials.id,
"provider": aiml_api_credentials.provider,
"type": aiml_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "aiml_api"
]
)
}
costs.append(
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter=cost_filter,
cost_amount=cost_amount,
)
)
return costs
async def refresh_llm_costs() -> None:
"""
Refresh LLM costs from the registry. All costs now come from the database.
This function also checks for active model migrations with custom pricing overrides
and applies them to ensure accurate billing.
"""
LLM_COST.clear()
LLM_COST.extend(await _build_llm_costs_from_registry())
# Initial load will happen after registry is refreshed at startup
# Don't call refresh_llm_costs() here - it will be called after registry refresh
# =============== This is the exhaustive list of cost for each Block =============== #

View File

@@ -33,7 +33,6 @@ from backend.util import type as type_utils
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.request import parse_url
from .block import BlockInput
from .db import BaseDbModel
@@ -450,9 +449,6 @@ class GraphModel(Graph, GraphMeta):
continue
if ProviderName.HTTP in field.provider:
continue
# MCP credentials are intentionally split by server URL
if ProviderName.MCP in field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
@@ -509,18 +505,6 @@ class GraphModel(Graph, GraphMeta):
"required": ["id", "provider", "type"],
}
# Add a descriptive display title when URL-based discriminator values
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
if (
field_info.discriminator
and not field_info.discriminator_mapping
and field_info.discriminator_values
):
hostnames = sorted(
parse_url(str(v)).netloc for v in field_info.discriminator_values
)
field_schema["display_name"] = ", ".join(hostnames)
# Add other (optional) field info items
field_schema.update(
field_info.model_dump(
@@ -565,17 +549,8 @@ class GraphModel(Graph, GraphMeta):
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
# A node's credentials are optional if either:
# 1. The node metadata says so (credentials_optional=True), or
# 2. All credential fields on the block have defaults (not required by schema)
block_required = node.block.input_schema.get_required_fields()
creds_required_by_schema = any(
fname in block_required
for fname in node.block.input_schema.get_credentials_fields()
)
node_required_map[node.id] = (
not node.credentials_optional and creds_required_by_schema
)
# Track if this node requires credentials (credentials_optional=False means required)
node_required_map[node.id] = not node.credentials_optional
for (
field_name,
@@ -801,19 +776,6 @@ class GraphModel(Graph, GraphMeta):
"'credentials' and `*_credentials` are reserved"
)
# Check custom block-level validation (e.g., MCP dynamic tool arguments).
# Blocks can override get_missing_input to report additional missing fields
# beyond the standard top-level required fields.
if for_run:
credential_fields = InputSchema.get_credentials_fields()
custom_missing = InputSchema.get_missing_input(node.input_default)
for field_name in custom_missing:
if (
field_name not in provided_inputs
and field_name not in credential_fields
):
node_errors[node.id][field_name] = "This field is required"
# Get input schema properties and check dependencies
input_fields = InputSchema.model_fields
@@ -1663,8 +1625,10 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
# Get all model slugs from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block

View File

@@ -462,120 +462,3 @@ def test_node_credentials_optional_with_other_metadata():
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"
# ============================================================================
# Tests for MCP Credential Deduplication
# ============================================================================
def test_mcp_credential_combine_different_servers():
"""Two MCP credential fields with different server URLs should produce
separate entries when combined (not merged into one)."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_sentry = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
field_linear = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.linear.app/mcp"},
)
combined = CredentialsFieldInfo.combine(
(field_sentry, ("node-sentry", "credentials")),
(field_linear, ("node-linear", "credentials")),
)
# Should produce 2 separate credential entries
assert len(combined) == 2, (
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
f"got {len(combined)}: {list(combined.keys())}"
)
# Each entry should contain the server hostname in its key
keys = list(combined.keys())
assert any(
"mcp.sentry.dev" in k for k in keys
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
assert any(
"mcp.linear.app" in k for k in keys
), f"Expected 'mcp.linear.app' in one key, got {keys}"
def test_mcp_credential_combine_same_server():
"""Two MCP credential fields with the same server URL should be combined
into one credential entry."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_a = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
field_b = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
combined = CredentialsFieldInfo.combine(
(field_a, ("node-a", "credentials")),
(field_b, ("node-b", "credentials")),
)
# Should produce 1 credential entry (same server URL)
assert len(combined) == 1, (
f"Expected 1 credential entry for 2 MCP blocks with same server, "
f"got {len(combined)}: {list(combined.keys())}"
)
def test_mcp_credential_combine_no_discriminator_values():
"""MCP credential fields without discriminator_values should be merged
into a single entry (backwards compat for blocks without server_url set)."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_a = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
)
field_b = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
)
combined = CredentialsFieldInfo.combine(
(field_a, ("node-a", "credentials")),
(field_b, ("node-b", "credentials")),
)
# Should produce 1 entry (no URL differentiation)
assert len(combined) == 1, (
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
f"got {len(combined)}: {list(combined.keys())}"
)

View File

@@ -0,0 +1,72 @@
"""
LLM Registry module for managing LLM models, providers, and costs dynamically.
This module provides a database-driven registry system for LLM models,
replacing hardcoded model configurations with a flexible admin-managed system.
"""
from backend.data.llm_registry.model import ModelMetadata
# Re-export for backwards compatibility
from backend.data.llm_registry.notifications import (
REGISTRY_REFRESH_CHANNEL,
publish_registry_refresh_notification,
subscribe_to_registry_refresh,
)
from backend.data.llm_registry.registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_default_model_slug,
get_dynamic_model_slugs,
get_fallback_model_for_disabled,
get_llm_discriminator_mapping,
get_llm_model_cost,
get_llm_model_metadata,
get_llm_model_schema_options,
get_model_info,
is_model_enabled,
iter_dynamic_models,
refresh_llm_registry,
register_static_costs,
register_static_metadata,
)
from backend.data.llm_registry.schema_utils import (
is_llm_model_field,
refresh_llm_discriminator_mapping,
refresh_llm_model_options,
update_schema_with_llm_registry,
)
__all__ = [
# Types
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Registry functions
"get_all_model_slugs_for_validation",
"get_default_model_slug",
"get_dynamic_model_slugs",
"get_fallback_model_for_disabled",
"get_llm_discriminator_mapping",
"get_llm_model_cost",
"get_llm_model_metadata",
"get_llm_model_schema_options",
"get_model_info",
"is_model_enabled",
"iter_dynamic_models",
"refresh_llm_registry",
"register_static_costs",
"register_static_metadata",
# Notifications
"REGISTRY_REFRESH_CHANNEL",
"publish_registry_refresh_notification",
"subscribe_to_registry_refresh",
# Schema utilities
"is_llm_model_field",
"refresh_llm_discriminator_mapping",
"refresh_llm_model_options",
"update_schema_with_llm_registry",
]

View File

@@ -0,0 +1,25 @@
"""Type definitions for LLM model metadata."""
from typing import Literal, NamedTuple
class ModelMetadata(NamedTuple):
"""Metadata for an LLM model.
Attributes:
provider: The provider identifier (e.g., "openai", "anthropic")
context_window: Maximum context window size in tokens
max_output_tokens: Maximum output tokens (None if unlimited)
display_name: Human-readable name for the model
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
creator_name: Name of the organization that created the model
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
"""
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]

View File

@@ -0,0 +1,89 @@
"""
Redis pub/sub notifications for LLM registry updates.
When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""
import asyncio
import logging
from typing import Any
from backend.data.redis_client import connect_async
logger = logging.getLogger(__name__)
# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
"""
Publish a notification to Redis that the LLM registry has been updated.
All executor services subscribed to this channel will refresh their registry.
"""
try:
redis = await connect_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.info("Published LLM registry refresh notification to Redis")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_refresh(
on_refresh: Any, # Async callable that takes no args
) -> None:
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
"""
try:
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications on channel: %s",
REGISTRY_REFRESH_CHANNEL,
)
# Process messages in a loop
while True:
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
try:
await on_refresh()
except Exception as exc:
logger.error(
"Error refreshing LLM registry from notification: %s",
exc,
exc_info=True,
)
except Exception as exc:
logger.warning(
"Error processing registry refresh message: %s", exc, exc_info=True
)
# Continue listening even if one message fails
await asyncio.sleep(1)
except Exception as exc:
logger.error(
"Failed to subscribe to LLM registry refresh notifications: %s",
exc,
exc_info=True,
)
raise

View File

@@ -0,0 +1,388 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Iterable
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
_lock = asyncio.Lock()
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
"""Register static metadata for legacy models (deprecated)."""
_static_metadata.update({str(key): value for key, value in metadata.items()})
_refresh_cached_schema()
def register_static_costs(costs: dict[Any, int]) -> None:
"""Register static costs for legacy models (deprecated)."""
_static_costs.update({str(key): value for key, value in costs.items()})
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
options: list[dict[str, str]] = []
# Only include enabled models in the dropdown options
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
options.append(
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
)
for slug, metadata in _static_metadata.items():
if slug in _dynamic_models:
continue
options.append(
{
"label": slug,
"value": slug,
"group": metadata.provider,
"description": "",
}
)
return options
async def refresh_llm_registry() -> None:
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Found %d LLM model records in database", len(records))
except Exception as exc:
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
)
return
dynamic: dict[str, RegistryModel] = {}
for record in records:
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display_name = (
record.Provider.displayName if record.Provider else record.providerId
)
# Creator name: prefer Creator.name, fallback to provider display name
creator_name = (
record.Creator.name if record.Creator else provider_display_name
)
# Price tier: default to 1 (cheapest) if not set
price_tier = getattr(record, "priceTier", 1) or 1
# Clamp to valid range 1-3
price_tier = max(1, min(3, price_tier))
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
display_name=record.displayName,
provider_name=provider_display_name,
creator_name=creator_name,
price_tier=price_tier, # type: ignore[arg-type]
)
costs = tuple(
RegistryModelCost(
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
)
for cost in (record.Costs or [])
)
# Map creator if present
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
dynamic[record.slug] = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
# Atomic swap - build new structures then replace references
# This ensures readers never see partially updated state
global _dynamic_models
_dynamic_models = dynamic
_refresh_cached_schema()
logger.info(
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
len(dynamic),
sum(1 for m in dynamic.values() if m.is_enabled),
sum(1 for m in dynamic.values() if not m.is_enabled),
)
def _refresh_cached_schema() -> None:
"""Refresh cached schema options and discriminator mapping."""
global _schema_options, _discriminator_mapping
# Build new structures
new_options = _build_schema_options()
new_mapping = {
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
}
for slug, metadata in _static_metadata.items():
new_mapping.setdefault(slug, metadata.provider)
# Atomic swap - replace references to ensure readers see consistent state
_schema_options = new_options
_discriminator_mapping = new_mapping
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
if slug in _dynamic_models:
return _dynamic_models[slug].metadata
return _static_metadata.get(slug)
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
"""Get model cost configuration by slug."""
if slug in _dynamic_models:
return _dynamic_models[slug].costs
cost_value = _static_costs.get(slug)
if cost_value is None:
return tuple()
return (
RegistryModelCost(
credit_cost=cost_value,
credential_provider="static",
credential_id=None,
credential_type=None,
currency=None,
metadata={},
),
)
def get_llm_model_schema_options() -> list[dict[str, str]]:
"""
Get schema options for LLM model selection dropdown.
Returns a copy of cached schema options that are refreshed when the registry is
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return list(_schema_options)
def get_llm_discriminator_mapping() -> dict[str, str]:
"""
Get discriminator mapping for LLM models.
Returns a copy of cached discriminator mapping that is refreshed when the registry
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return dict(_discriminator_mapping)
def get_dynamic_model_slugs() -> set[str]:
"""Get all dynamic model slugs from the registry."""
return set(_dynamic_models.keys())
def get_all_model_slugs_for_validation() -> set[str]:
"""
Get ALL model slugs (both enabled and disabled) for validation purposes.
This is used for JSON schema enum validation - we need to accept any known
model value (even disabled ones) so that existing graphs don't fail validation.
The actual fallback/enforcement happens at runtime in llm_call().
"""
all_slugs = set(_dynamic_models.keys())
all_slugs.update(_static_metadata.keys())
return all_slugs
def iter_dynamic_models() -> Iterable[RegistryModel]:
"""Iterate over all dynamic models in the registry."""
return tuple(_dynamic_models.values())
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
"""
Find a fallback model when the requested model is disabled.
Looks for an enabled model from the same provider. Prefers models with
similar names or capabilities if possible.
Args:
disabled_model_slug: The slug of the disabled model
Returns:
An enabled RegistryModel from the same provider, or None if no fallback found
"""
disabled_model = _dynamic_models.get(disabled_model_slug)
if not disabled_model:
return None
provider = disabled_model.metadata.provider
# Find all enabled models from the same provider
candidates = [
model
for model in _dynamic_models.values()
if model.is_enabled and model.metadata.provider == provider
]
if not candidates:
return None
# Sort by: prefer models with similar context window, then by name
candidates.sort(
key=lambda m: (
abs(m.metadata.context_window - disabled_model.metadata.context_window),
m.display_name.lower(),
)
)
return candidates[0]
def is_model_enabled(model_slug: str) -> bool:
"""Check if a model is enabled in the registry."""
model = _dynamic_models.get(model_slug)
if not model:
# Model not in registry - assume it's a static/legacy model and allow it
return True
return model.is_enabled
def get_model_info(model_slug: str) -> RegistryModel | None:
"""Get model info from the registry."""
return _dynamic_models.get(model_slug)
def get_default_model_slug() -> str | None:
"""
Get the default model slug to use for block defaults.
Returns the recommended model if set (configured via admin UI),
otherwise returns the first enabled model alphabetically.
Returns None if no models are available or enabled.
"""
# Return the recommended model if one is set and enabled
for model in _dynamic_models.values():
if model.is_recommended and model.is_enabled:
return model.slug
# No recommended model set - find first enabled model alphabetically
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
logger.warning(
"No recommended model set, using '%s' as default",
model.slug,
)
return model.slug
# No enabled models available
if _dynamic_models:
logger.error(
"No enabled models found in registry (%d models registered but all disabled)",
len(_dynamic_models),
)
else:
logger.error("No models registered in LLM registry")
return None

View File

@@ -0,0 +1,130 @@
"""
Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
from backend.data.llm_registry.registry import (
get_all_model_slugs_for_validation,
get_default_model_slug,
get_llm_discriminator_mapping,
get_llm_model_schema_options,
)
logger = logging.getLogger(__name__)
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
"""
Check if a field is an LLM model selection field.
Returns True if the field has 'options' in json_schema_extra
(set by llm_model_schema_extra() in blocks/llm.py).
"""
if not hasattr(field_info, "json_schema_extra"):
return False
extra = field_info.json_schema_extra
if isinstance(extra, dict):
return "options" in extra
return False
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options from the registry.
Updates 'options' (for frontend dropdown) to show only enabled models,
but keeps the 'enum' (for validation) inclusive of ALL known models.
This is important because:
- Options: What users see in the dropdown (enabled models only)
- Enum: What values pass validation (all known models, including disabled)
Existing graphs may have disabled models selected - they should pass validation
and the fallback logic in llm_call() will handle using an alternative model.
"""
fresh_options = get_llm_model_schema_options()
if not fresh_options:
return
# Update options array (UI dropdown) - only enabled models
if "options" in field_schema:
field_schema["options"] = fresh_options
all_known_slugs = get_all_model_slugs_for_validation()
if all_known_slugs and "enum" in field_schema:
existing_enum = set(field_schema.get("enum", []))
combined_enum = existing_enum | all_known_slugs
field_schema["enum"] = sorted(combined_enum)
# Set the default value from the registry (gpt-4o if available, else first enabled)
# This ensures new blocks have a sensible default pre-selected
default_slug = get_default_model_slug()
if default_slug:
field_schema["default"] = default_slug
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
"""
Refresh discriminator_mapping for fields that use model-based discrimination.
The discriminator is already set when AICredentialsField() creates the field.
We only need to refresh the mapping when models are added/removed.
"""
if field_schema.get("discriminator") != "model":
return
# Always refresh the mapping to get latest models
fresh_mapping = get_llm_discriminator_mapping()
if fresh_mapping is not None:
field_schema["discriminator_mapping"] = fresh_mapping
def update_schema_with_llm_registry(
schema: dict[str, Any], model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
Refreshes:
1. Model options for LLM model selection fields (dropdown choices)
2. Discriminator mappings for credentials fields (model → provider)
Args:
schema: The JSON schema to update (mutated in-place)
model_class: The Pydantic model class (optional, for field introspection)
"""
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Refresh model options for LLM model fields
if model_class and hasattr(model_class, "model_fields"):
field_info = model_class.model_fields.get(field_name)
if field_info and is_llm_model_field(field_name, field_info):
try:
refresh_llm_model_options(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name,
exc,
)
# Refresh discriminator mapping for fields that use model discrimination
try:
refresh_llm_discriminator_mapping(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name,
exc,
)

View File

@@ -29,7 +29,6 @@ from pydantic import (
GetCoreSchemaHandler,
SecretStr,
field_serializer,
model_validator,
)
from pydantic_core import (
CoreSchema,
@@ -40,6 +39,7 @@ from pydantic_core import (
)
from typing_extensions import TypedDict
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.request import parse_url
@@ -503,25 +503,6 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT
@model_validator(mode="before")
@classmethod
def _normalize_legacy_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
instead of the plain value. Old stored credential references may have
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
"""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@classmethod
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
return get_args(cls.model_fields["provider"].annotation)
@@ -570,7 +551,9 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Do not return anything, just mutate schema in place
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
@@ -626,18 +609,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
] = defaultdict(list)
for field, key in fields:
if (
field.discriminator
and not field.discriminator_mapping
and field.discriminator_values
):
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
# Each unique host gets its own credential entry.
provider_prefix = next(iter(field.provider))
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
if field.provider == frozenset([ProviderName.HTTP]):
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
# Group by host extracted from the URL
providers = frozenset(
[cast(CP, prefix_str)]
[cast(CP, "http")]
+ [
cast(CP, parse_url(str(value)).netloc)
for value in field.discriminator_values
@@ -732,16 +708,20 @@ def CredentialsField(
This is enforced by the `BlockSchema` base class.
"""
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
field_schema_extra: dict[str, Any] = {}
# Always include discriminator if provided
if discriminator is not None:
field_schema_extra["discriminator"] = discriminator
# Always include discriminator_mapping when discriminator is set (even if empty initially)
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
# Include other optional fields (only if not None)
if required_scopes:
field_schema_extra["credentials_scopes"] = list(required_scopes)
if discriminator_values:
field_schema_extra["discriminator_values"] = discriminator_values
# Merge any json_schema_extra passed in kwargs
if "json_schema_extra" in kwargs:

View File

@@ -0,0 +1,67 @@
"""
Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.blocks._base import BlockSchema
from backend.data import db, llm_registry
from backend.data.block import initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import subscribe_to_registry_refresh
logger = logging.getLogger(__name__)
async def initialize_registry_for_executor() -> None:
"""
Initialize blocks and refresh LLM registry in the executor context.
This must run in the executor's event loop to have access to the database.
"""
try:
# Connect to database if not already connected
if not db.is_connected():
await db.connect()
logger.info("[GraphExecutor] Connected to database for registry refresh")
# Initialize blocks (internally refreshes LLM registry and costs)
await initialize_blocks()
logger.info("[GraphExecutor] Blocks initialized")
except Exception as exc:
logger.warning(
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
exc,
exc_info=True,
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
await refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_updates() -> None:
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
await subscribe_to_registry_refresh(refresh_registry_on_notification)

View File

@@ -20,7 +20,6 @@ from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.credit import UsageTransactionMetadata
@@ -229,18 +228,6 @@ async def execute_node(
_input_data.nodes_input_masks = nodes_input_masks
_input_data.user_id = user_id
input_data = _input_data.model_dump()
elif isinstance(node_block, MCPToolBlock):
_mcp_data = MCPToolBlock.Input(**node.input_default)
# Dynamic tool fields are flattened to top-level by validate_exec
# (via get_input_defaults). Collect them back into tool_arguments.
tool_schema = _mcp_data.tool_input_schema
tool_props = set(tool_schema.get("properties", {}).keys())
merged_args = {**_mcp_data.tool_arguments}
for key in tool_props:
if key in input_data:
merged_args[key] = input_data[key]
_mcp_data.tool_arguments = merged_args
input_data = _mcp_data.model_dump()
data.inputs = input_data
# Execute the node
@@ -277,34 +264,8 @@ async def execute_node(
# Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items():
field_value = input_data.get(field_name)
if not field_value or (
isinstance(field_value, dict) and not field_value.get("id")
):
# No credentials configured — nullify so JSON schema validation
# doesn't choke on the empty default `{}`.
input_data[field_name] = None
continue # Block runs without credentials
credentials_meta = input_type(**field_value)
# Write normalized values back so JSON schema validation also passes
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
input_data[field_name] = credentials_meta.model_dump(mode="json")
try:
credentials, lock = await creds_manager.acquire(
user_id, credentials_meta.id
)
except ValueError:
# Credential was deleted or doesn't exist.
# If the field has a default, run without credentials.
if input_model.model_fields[field_name].default is not None:
log_metadata.warning(
f"Credentials #{credentials_meta.id} not found, "
"running without (field has default)"
)
input_data[field_name] = None
continue
raise
credentials_meta = input_type(**input_data[field_name])
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
creds_locks.append(lock)
extra_exec_kwargs[field_name] = credentials
@@ -747,6 +708,20 @@ class ExecutionProcessor:
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
# Initialize LLM registry and subscribe to updates
from backend.executor.llm_registry_init import (
initialize_registry_for_executor,
subscribe_to_registry_updates,
)
asyncio.run_coroutine_threadsafe(
initialize_registry_for_executor(), self.node_execution_loop
)
asyncio.run_coroutine_threadsafe(
subscribe_to_registry_updates(), self.node_execution_loop
)
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)

View File

@@ -260,13 +260,7 @@ async def _validate_node_input_credentials(
# Track if any credential field is missing for this node
has_missing_credentials = False
# A credential field is optional if the node metadata says so, or if
# the block schema declares a default for the field.
required_fields = block.input_schema.get_required_fields()
is_creds_optional = node.credentials_optional
for field_name, credentials_meta_type in credentials_fields.items():
field_is_optional = is_creds_optional or field_name not in required_fields
try:
# Check nodes_input_masks first, then input_default
field_value = None
@@ -279,7 +273,7 @@ async def _validate_node_input_credentials(
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if field_is_optional:
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
@@ -289,8 +283,8 @@ async def _validate_node_input_credentials(
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If credential field is optional, skip instead of error
if field_is_optional:
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
@@ -340,16 +334,16 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, allow running without.
# The executor will pass credentials=None to the block's run().
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and is_creds_optional
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id}: optional credentials not configured, "
"running without"
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip

View File

@@ -495,7 +495,6 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
mock_node.block = mock_block
# Create mock graph
@@ -509,8 +508,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
nodes_input_masks=None,
)
# Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
assert mock_node.id not in nodes_to_skip
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@@ -536,7 +535,6 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
mock_node.block = mock_block
# Create mock graph

View File

@@ -22,27 +22,6 @@ from backend.util.settings import Settings
settings = Settings()
def provider_matches(stored: str, expected: str) -> bool:
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
instead of ``"mcp"``. OAuth states persisted with the buggy format need
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
"""
if stored == expected:
return True
if stored.startswith("ProviderName."):
member = stored.removeprefix("ProviderName.")
from backend.integrations.providers import ProviderName
try:
return ProviderName[member].value == expected
except KeyError:
pass
return False
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
ollama_credentials = APIKeyCredentials(
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
@@ -410,7 +389,7 @@ class IntegrationCredentialsStore:
self, user_id: str, provider: str
) -> list[Credentials]:
credentials = await self.get_all_creds(user_id)
return [c for c in credentials if provider_matches(c.provider, provider)]
return [c for c in credentials if c.provider == provider]
async def get_authorized_providers(self, user_id: str) -> list[str]:
credentials = await self.get_all_creds(user_id)
@@ -506,6 +485,17 @@ class IntegrationCredentialsStore:
async with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.oauth_states.append(state)
async with await self.locked_user_integrations(user_id):
user_integrations = await self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
user_integrations.oauth_states = oauth_states
await self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
return token, code_challenge
def _generate_code_challenge(self) -> tuple[str, str]:
@@ -531,7 +521,7 @@ class IntegrationCredentialsStore:
state
for state in oauth_states
if secrets.compare_digest(state.token, token)
and provider_matches(state.provider, provider)
and state.provider == provider
and state.expires_at > now.timestamp()
),
None,

View File

@@ -9,10 +9,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import (
IntegrationCredentialsStore,
provider_matches,
)
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
@@ -140,10 +137,7 @@ class IntegrationCredentialsManager:
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
) -> OAuth2Credentials:
async with self._locked(user_id, credentials.id, "refresh"):
if provider_matches(credentials.provider, ProviderName.MCP.value):
oauth_handler = create_mcp_oauth_handler(credentials)
else:
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials):
logger.debug(
f"Refreshing '{credentials.provider}' "
@@ -242,31 +236,3 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
client_secret=client_secret,
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
)
def create_mcp_oauth_handler(
credentials: OAuth2Credentials,
) -> "BaseOAuthHandler":
"""Create an MCPOAuthHandler from credential metadata for token refresh.
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
is reconstructed from metadata stored on the credential during initial auth.
"""
from backend.blocks.mcp.oauth import MCPOAuthHandler
meta = credentials.metadata or {}
token_url = meta.get("mcp_token_url", "")
if not token_url:
raise ValueError(
f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; "
"cannot refresh tokens"
)
return MCPOAuthHandler(
client_id=meta.get("mcp_client_id", ""),
client_secret=meta.get("mcp_client_secret", ""),
redirect_uri="", # Not needed for token refresh
authorize_url="", # Not needed for token refresh
token_url=token_url,
resource_url=meta.get("mcp_resource_url"),
)

View File

@@ -30,7 +30,6 @@ class ProviderName(str, Enum):
IDEOGRAM = "ideogram"
JINA = "jina"
LLAMA_API = "llama_api"
MCP = "mcp"
MEDIUM = "medium"
MEM0 = "mem0"
NOTION = "notion"

View File

@@ -51,21 +51,6 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
if (
creds_meta := new_node.input_default.get(creds_field_name)
) and not await get_credentials(creds_meta["id"]):
# If the credential field is optional (has a default in the
# schema, or node metadata marks it optional), clear the stale
# reference instead of blocking the save.
creds_field_optional = (
new_node.credentials_optional
or creds_field_name not in block_input_schema.get_required_fields()
)
if creds_field_optional:
new_node.input_default[creds_field_name] = {}
logger.warning(
f"Node #{new_node.id}: cleared stale optional "
f"credentials #{creds_meta['id']} for "
f"'{creds_field_name}'"
)
continue
raise ValueError(
f"Node #{new_node.id} input '{creds_field_name}' updated with "
f"non-existent credentials #{creds_meta['id']}"

View File

@@ -0,0 +1,941 @@
from __future__ import annotations
from typing import Any, Iterable, Sequence, cast
import prisma
import prisma.models
from backend.data.db import transaction
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
def _json_dict(value: Any | None) -> dict[str, Any]:
if not value:
return {}
if isinstance(value, dict):
return value
return {}
def _map_cost(record: prisma.models.LlmModelCost) -> llm_model.LlmModelCost:
return llm_model.LlmModelCost(
id=record.id,
unit=record.unit,
credit_cost=record.creditCost,
credential_provider=record.credentialProvider,
credential_id=record.credentialId,
credential_type=record.credentialType,
currency=record.currency,
metadata=_json_dict(record.metadata),
)
def _map_creator(
record: prisma.models.LlmModelCreator,
) -> llm_model.LlmModelCreator:
return llm_model.LlmModelCreator(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
website_url=record.websiteUrl,
logo_url=record.logoUrl,
metadata=_json_dict(record.metadata),
)
def _map_model(record: prisma.models.LlmModel) -> llm_model.LlmModel:
costs = []
if record.Costs:
costs = [_map_cost(cost) for cost in record.Costs]
creator = None
if hasattr(record, "Creator") and record.Creator:
creator = _map_creator(record.Creator)
return llm_model.LlmModel(
id=record.id,
slug=record.slug,
display_name=record.displayName,
description=record.description,
provider_id=record.providerId,
creator_id=record.creatorId,
creator=creator,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
capabilities=_json_dict(record.capabilities),
metadata=_json_dict(record.metadata),
costs=costs,
)
def _map_provider(record: prisma.models.LlmProvider) -> llm_model.LlmProvider:
models: list[llm_model.LlmModel] = []
if record.Models:
models = [_map_model(model) for model in record.Models]
return llm_model.LlmProvider(
id=record.id,
name=record.name,
display_name=record.displayName,
description=record.description,
default_credential_provider=record.defaultCredentialProvider,
default_credential_id=record.defaultCredentialId,
default_credential_type=record.defaultCredentialType,
supports_tools=record.supportsTools,
supports_json_output=record.supportsJsonOutput,
supports_reasoning=record.supportsReasoning,
supports_parallel_tool=record.supportsParallelTool,
metadata=_json_dict(record.metadata),
models=models,
)
async def list_providers(
include_models: bool = True, enabled_only: bool = False
) -> list[llm_model.LlmProvider]:
"""
List all LLM providers.
Args:
include_models: Whether to include models for each provider
enabled_only: If True, only include enabled models (for public routes)
"""
include: Any = None
if include_models:
model_where = {"isEnabled": True} if enabled_only else None
include = {
"Models": {
"include": {"Costs": True, "Creator": True},
"where": model_where,
}
}
records = await prisma.models.LlmProvider.prisma().find_many(include=include)
return [_map_provider(record) for record in records]
async def upsert_provider(
request: llm_model.UpsertLlmProviderRequest,
provider_id: str | None = None,
) -> llm_model.LlmProvider:
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"defaultCredentialProvider": request.default_credential_provider,
"defaultCredentialId": request.default_credential_id,
"defaultCredentialType": request.default_credential_type,
"supportsTools": request.supports_tools,
"supportsJsonOutput": request.supports_json_output,
"supportsReasoning": request.supports_reasoning,
"supportsParallelTool": request.supports_parallel_tool,
"metadata": prisma.Json(request.metadata or {}),
}
include: Any = {"Models": {"include": {"Costs": True, "Creator": True}}}
if provider_id:
record = await prisma.models.LlmProvider.prisma().update(
where={"id": provider_id},
data=data,
include=include,
)
else:
record = await prisma.models.LlmProvider.prisma().create(
data=data,
include=include,
)
if record is None:
raise ValueError("Failed to create/update provider")
return _map_provider(record)
async def delete_provider(provider_id: str) -> bool:
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Due to onDelete: Restrict on LlmModel.Provider, the database will
block deletion if models exist.
Args:
provider_id: UUID of the provider to delete
Returns:
True if deleted successfully
Raises:
ValueError: If provider not found or has associated models
"""
# Check if provider exists
provider = await prisma.models.LlmProvider.prisma().find_unique(
where={"id": provider_id},
include={"Models": True},
)
if not provider:
raise ValueError(f"Provider with id '{provider_id}' not found")
# Check if provider has any models
model_count = len(provider.Models) if provider.Models else 0
if model_count > 0:
raise ValueError(
f"Cannot delete provider '{provider.displayName}' because it has "
f"{model_count} model(s). Delete all models first."
)
# Safe to delete
await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
return True
async def list_models(
provider_id: str | None = None,
enabled_only: bool = False,
page: int = 1,
page_size: int = 50,
) -> llm_model.LlmModelsResponse:
"""
List LLM models with pagination.
Args:
provider_id: Optional filter by provider ID
enabled_only: If True, only return enabled models (for public routes)
page: Page number (1-indexed)
page_size: Number of models per page
"""
# Validate pagination inputs to avoid runtime errors
if page_size < 1:
page_size = 50
if page < 1:
page = 1
where: Any = {}
if provider_id:
where["providerId"] = provider_id
if enabled_only:
where["isEnabled"] = True
# Get total count for pagination
total_items = await prisma.models.LlmModel.prisma().count(
where=where if where else None
)
# Calculate pagination
skip = (page - 1) * page_size
total_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0
records = await prisma.models.LlmModel.prisma().find_many(
where=where if where else None,
include={"Costs": True, "Creator": True},
skip=skip,
take=page_size,
)
models = [_map_model(record) for record in records]
return llm_model.LlmModelsResponse(
models=models,
pagination=Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
def _cost_create_payload(
costs: Sequence[llm_model.LlmModelCostInput],
) -> dict[str, Iterable[dict[str, Any]]]:
create_items = []
for cost in costs:
item: dict[str, Any] = {
"unit": cost.unit,
"creditCost": cost.credit_cost,
"credentialProvider": cost.credential_provider,
}
# Only include optional fields if they have values
if cost.credential_id:
item["credentialId"] = cost.credential_id
if cost.credential_type:
item["credentialType"] = cost.credential_type
if cost.currency:
item["currency"] = cost.currency
# Handle metadata - use Prisma Json type
if cost.metadata is not None and cost.metadata != {}:
item["metadata"] = prisma.Json(cost.metadata)
create_items.append(item)
return {"create": create_items}
async def create_model(
request: llm_model.CreateLlmModelRequest,
) -> llm_model.LlmModel:
data: Any = {
"slug": request.slug,
"displayName": request.display_name,
"description": request.description,
"Provider": {"connect": {"id": request.provider_id}},
"contextWindow": request.context_window,
"maxOutputTokens": request.max_output_tokens,
"isEnabled": request.is_enabled,
"capabilities": prisma.Json(request.capabilities or {}),
"metadata": prisma.Json(request.metadata or {}),
"Costs": _cost_create_payload(request.costs),
}
if request.creator_id:
data["Creator"] = {"connect": {"id": request.creator_id}}
record = await prisma.models.LlmModel.prisma().create(
data=data,
include={"Costs": True, "Creator": True, "Provider": True},
)
return _map_model(record)
async def update_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
) -> llm_model.LlmModel:
# Build scalar field updates (non-relation fields)
scalar_data: Any = {}
if request.display_name is not None:
scalar_data["displayName"] = request.display_name
if request.description is not None:
scalar_data["description"] = request.description
if request.context_window is not None:
scalar_data["contextWindow"] = request.context_window
if request.max_output_tokens is not None:
scalar_data["maxOutputTokens"] = request.max_output_tokens
if request.is_enabled is not None:
scalar_data["isEnabled"] = request.is_enabled
if request.capabilities is not None:
scalar_data["capabilities"] = request.capabilities
if request.metadata is not None:
scalar_data["metadata"] = request.metadata
# Foreign keys can be updated directly as scalar fields
if request.provider_id is not None:
scalar_data["providerId"] = request.provider_id
if request.creator_id is not None:
# Empty string means remove the creator
scalar_data["creatorId"] = request.creator_id if request.creator_id else None
# If we have costs to update, we need to handle them separately
# because nested writes have different constraints
if request.costs is not None:
# Wrap cost replacement in a transaction for atomicity
async with transaction() as tx:
# First update scalar fields
if scalar_data:
await tx.llmmodel.update(
where={"id": model_id},
data=scalar_data,
)
# Then handle costs: delete existing and create new
await tx.llmmodelcost.delete_many(where={"llmModelId": model_id})
if request.costs:
cost_payload = _cost_create_payload(request.costs)
for cost_item in cost_payload["create"]:
cost_item["llmModelId"] = model_id
await tx.llmmodelcost.create(data=cast(Any, cost_item))
# Fetch the updated record (outside transaction)
record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
else:
# No costs update - simple update
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data=scalar_data,
include={"Costs": True, "Creator": True},
)
if not record:
raise ValueError(f"Model with id '{model_id}' not found")
return _map_model(record)
async def toggle_model(
model_id: str,
is_enabled: bool,
migrate_to_slug: str | None = None,
migration_reason: str | None = None,
custom_credit_cost: int | None = None,
) -> llm_model.ToggleLlmModelResponse:
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
Args:
model_id: UUID of the model to toggle
is_enabled: New enabled status
migrate_to_slug: If disabling and this is provided, migrate all workflows
using this model to the specified replacement model
migration_reason: Optional reason for the migration (e.g., "Provider outage")
custom_credit_cost: Optional custom pricing override for migrated workflows.
When set, the billing system should use this cost instead
of the target model's cost for affected nodes.
Returns:
ToggleLlmModelResponse with the updated model and optional migration stats
"""
import json
# Get the model being toggled
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
nodes_migrated = 0
migration_id: str | None = None
# If disabling with migration, perform migration first
if not is_enabled and migrate_to_slug:
# Validate replacement model exists and is enabled
replacement = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migrate_to_slug}
)
if not replacement:
raise ValueError(f"Replacement model '{migrate_to_slug}' not found")
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{migrate_to_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# Perform all operations atomically within a single transaction
# This ensures no nodes are missed between query and update
async with transaction() as tx:
# Get the IDs of nodes that will be migrated (inside transaction for consistency)
node_ids_result = await tx.query_raw(
"""
SELECT id
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
FOR UPDATE
""",
model.slug,
)
migrated_node_ids = (
[row["id"] for row in node_ids_result] if node_ids_result else []
)
nodes_migrated = len(migrated_node_ids)
if nodes_migrated > 0:
# Update by IDs to ensure we only update the exact nodes we queried
# Use JSON array and jsonb_array_elements_text for safe parameterization
node_ids_json = json.dumps(migrated_node_ids)
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
""",
migrate_to_slug,
node_ids_json,
)
record = await tx.llmmodel.update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
# Create migration record for revert capability
if nodes_migrated > 0:
migration_data: Any = {
"sourceModelSlug": model.slug,
"targetModelSlug": migrate_to_slug,
"reason": migration_reason,
"migratedNodeIds": json.dumps(migrated_node_ids),
"nodeCount": nodes_migrated,
"customCreditCost": custom_credit_cost,
}
migration_record = await tx.llmmodelmigration.create(
data=migration_data
)
migration_id = migration_record.id
else:
# Simple toggle without migration
record = await prisma.models.LlmModel.prisma().update(
where={"id": model_id},
data={"isEnabled": is_enabled},
include={"Costs": True},
)
if record is None:
raise ValueError(f"Model with id '{model_id}' not found")
return llm_model.ToggleLlmModelResponse(
model=_map_model(record),
nodes_migrated=nodes_migrated,
migrated_to_slug=migrate_to_slug if nodes_migrated > 0 else None,
migration_id=migration_id,
)
async def get_model_usage(model_id: str) -> llm_model.LlmModelUsageResponse:
"""Get usage count for a model."""
import prisma as prisma_module
model = await prisma.models.LlmModel.prisma().find_unique(where={"id": model_id})
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
count_result = await prisma_module.get_client().query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
model.slug,
)
node_count = int(count_result[0]["count"]) if count_result else 0
return llm_model.LlmModelUsageResponse(model_slug=model.slug, node_count=node_count)
async def delete_model(
model_id: str, replacement_model_slug: str | None = None
) -> llm_model.DeleteLlmModelResponse:
"""
Delete a model and optionally migrate all AgentNodes using it to a replacement model.
This performs an atomic operation within a database transaction:
1. Validates the model exists
2. Counts affected nodes
3. If nodes exist, validates replacement model and migrates them
4. Deletes the LlmModel record (CASCADE deletes costs)
Args:
model_id: UUID of the model to delete
replacement_model_slug: Slug of the model to migrate to (required only if nodes use this model)
Returns:
DeleteLlmModelResponse with migration stats
Raises:
ValueError: If model not found, nodes exist but no replacement provided,
replacement not found, or replacement is disabled
"""
# 1. Get the model being deleted (early validation - outside transaction)
model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}, include={"Costs": True}
)
if not model:
raise ValueError(f"Model with id '{model_id}' not found")
deleted_slug = model.slug
deleted_display_name = model.displayName
# 2. Perform all mutation logic atomically within a transaction
# This prevents TOCTOU issues where nodes could be created between count and delete
async with transaction() as tx:
# Count affected nodes inside the transaction
count_result = await tx.query_raw(
"""
SELECT COUNT(*) as count
FROM "AgentNode"
WHERE "constantInput"::jsonb->>'model' = $1
""",
deleted_slug,
)
nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0
# Validate replacement model only if there are nodes to migrate
if nodes_to_migrate > 0:
if not replacement_model_slug:
raise ValueError(
f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
f"are using it. Please provide a replacement_model_slug to migrate them."
)
replacement = await tx.llmmodel.find_unique(
where={"slug": replacement_model_slug}
)
if not replacement:
raise ValueError(
f"Replacement model '{replacement_model_slug}' not found"
)
if not replacement.isEnabled:
raise ValueError(
f"Replacement model '{replacement_model_slug}' is disabled. "
f"Please enable it before using it as a replacement."
)
# Migrate all AgentNode.constantInput->model to replacement
await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE "constantInput"::jsonb->>'model' = $2
""",
replacement_model_slug,
deleted_slug,
)
# Delete the model (CASCADE will delete costs automatically)
await tx.llmmodel.delete(where={"id": model_id})
# Build appropriate message based on whether migration happened
if nodes_to_migrate > 0:
message = (
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}) "
f"and migrated {nodes_to_migrate} workflow node(s) to '{replacement_model_slug}'."
)
else:
message = (
f"Successfully deleted model '{deleted_display_name}' ({deleted_slug}). "
f"No workflows were using this model."
)
return llm_model.DeleteLlmModelResponse(
deleted_model_slug=deleted_slug,
deleted_model_display_name=deleted_display_name,
replacement_model_slug=replacement_model_slug,
nodes_migrated=nodes_to_migrate,
message=message,
)
def _map_migration(
record: prisma.models.LlmModelMigration,
) -> llm_model.LlmModelMigration:
return llm_model.LlmModelMigration(
id=record.id,
source_model_slug=record.sourceModelSlug,
target_model_slug=record.targetModelSlug,
reason=record.reason,
node_count=record.nodeCount,
custom_credit_cost=record.customCreditCost,
is_reverted=record.isReverted,
created_at=record.createdAt,
reverted_at=record.revertedAt,
)
async def list_migrations(
include_reverted: bool = False,
) -> list[llm_model.LlmModelMigration]:
"""
List model migrations, optionally including reverted ones.
Args:
include_reverted: If True, include reverted migrations. Default is False.
Returns:
List of LlmModelMigration records
"""
where: Any = None if include_reverted else {"isReverted": False}
records = await prisma.models.LlmModelMigration.prisma().find_many(
where=where,
order={"createdAt": "desc"},
)
return [_map_migration(record) for record in records]
async def get_migration(migration_id: str) -> llm_model.LlmModelMigration | None:
"""Get a specific migration by ID."""
record = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
return _map_migration(record) if record else None
async def revert_migration(
migration_id: str,
re_enable_source_model: bool = True,
) -> llm_model.RevertMigrationResponse:
"""
Revert a model migration, restoring affected nodes to their original model.
This only reverts the specific nodes that were migrated, not all nodes
currently using the target model.
Args:
migration_id: UUID of the migration to revert
re_enable_source_model: Whether to re-enable the source model if it's disabled
Returns:
RevertMigrationResponse with revert stats
Raises:
ValueError: If migration not found, already reverted, or source model not available
"""
import json
from datetime import datetime, timezone
# Get the migration record
migration = await prisma.models.LlmModelMigration.prisma().find_unique(
where={"id": migration_id}
)
if not migration:
raise ValueError(f"Migration with id '{migration_id}' not found")
if migration.isReverted:
raise ValueError(
f"Migration '{migration_id}' has already been reverted "
f"on {migration.revertedAt.isoformat() if migration.revertedAt else 'unknown date'}"
)
# Check if source model exists
source_model = await prisma.models.LlmModel.prisma().find_unique(
where={"slug": migration.sourceModelSlug}
)
if not source_model:
raise ValueError(
f"Source model '{migration.sourceModelSlug}' no longer exists. "
f"Cannot revert migration."
)
# Get the migrated node IDs (Prisma auto-parses JSONB to list)
migrated_node_ids: list[str] = (
migration.migratedNodeIds
if isinstance(migration.migratedNodeIds, list)
else json.loads(migration.migratedNodeIds) # type: ignore
)
if not migrated_node_ids:
raise ValueError("No nodes to revert in this migration")
# Track if we need to re-enable the source model
source_model_was_disabled = not source_model.isEnabled
should_re_enable = source_model_was_disabled and re_enable_source_model
source_model_re_enabled = False
# Perform revert atomically
async with transaction() as tx:
# Re-enable the source model if requested and it was disabled
if should_re_enable:
await tx.llmmodel.update(
where={"id": source_model.id},
data={"isEnabled": True},
)
source_model_re_enabled = True
# Update only the specific nodes that were migrated
# We need to check that they still have the target model (haven't been changed since)
# Use a single batch update for efficiency
# Use JSON array and jsonb_array_elements_text for safe parameterization
node_ids_json = json.dumps(migrated_node_ids)
result = await tx.execute_raw(
"""
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
to_jsonb($1::text)
)
WHERE id::text IN (
SELECT jsonb_array_elements_text($2::jsonb)
)
AND "constantInput"::jsonb->>'model' = $3
""",
migration.sourceModelSlug,
node_ids_json,
migration.targetModelSlug,
)
nodes_reverted = result if result else 0
# Mark migration as reverted
await tx.llmmodelmigration.update(
where={"id": migration_id},
data={
"isReverted": True,
"revertedAt": datetime.now(timezone.utc),
},
)
# Calculate nodes that were already changed since migration
nodes_already_changed = len(migrated_node_ids) - nodes_reverted
# Build appropriate message
message_parts = [
f"Successfully reverted migration: {nodes_reverted} node(s) restored "
f"from '{migration.targetModelSlug}' to '{migration.sourceModelSlug}'."
]
if nodes_already_changed > 0:
message_parts.append(
f" {nodes_already_changed} node(s) were already changed and not reverted."
)
if source_model_re_enabled:
message_parts.append(
f" Model '{migration.sourceModelSlug}' has been re-enabled."
)
return llm_model.RevertMigrationResponse(
migration_id=migration_id,
source_model_slug=migration.sourceModelSlug,
target_model_slug=migration.targetModelSlug,
nodes_reverted=nodes_reverted,
nodes_already_changed=nodes_already_changed,
source_model_re_enabled=source_model_re_enabled,
message="".join(message_parts),
)
# ============================================================================
# Creator CRUD operations
# ============================================================================
async def list_creators() -> list[llm_model.LlmModelCreator]:
"""List all LLM model creators."""
records = await prisma.models.LlmModelCreator.prisma().find_many(
order={"displayName": "asc"}
)
return [_map_creator(record) for record in records]
async def get_creator(creator_id: str) -> llm_model.LlmModelCreator | None:
"""Get a specific creator by ID."""
record = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
return _map_creator(record) if record else None
async def upsert_creator(
request: llm_model.UpsertLlmCreatorRequest,
creator_id: str | None = None,
) -> llm_model.LlmModelCreator:
"""Create or update a model creator."""
data: Any = {
"name": request.name,
"displayName": request.display_name,
"description": request.description,
"websiteUrl": request.website_url,
"logoUrl": request.logo_url,
"metadata": prisma.Json(request.metadata or {}),
}
if creator_id:
record = await prisma.models.LlmModelCreator.prisma().update(
where={"id": creator_id},
data=data,
)
else:
record = await prisma.models.LlmModelCreator.prisma().create(data=data)
if record is None:
raise ValueError("Failed to create/update creator")
return _map_creator(record)
async def delete_creator(creator_id: str) -> bool:
"""
Delete a model creator.
This will set creatorId to NULL on all associated models (due to onDelete: SetNull).
Args:
creator_id: UUID of the creator to delete
Returns:
True if deleted successfully
Raises:
ValueError: If creator not found
"""
creator = await prisma.models.LlmModelCreator.prisma().find_unique(
where={"id": creator_id}
)
if not creator:
raise ValueError(f"Creator with id '{creator_id}' not found")
await prisma.models.LlmModelCreator.prisma().delete(where={"id": creator_id})
return True
async def get_recommended_model() -> llm_model.LlmModel | None:
"""
Get the currently recommended LLM model.
Returns:
The recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
include={"Costs": True, "Creator": True},
)
return _map_model(record) if record else None
async def set_recommended_model(
model_id: str,
) -> tuple[llm_model.LlmModel, str | None]:
"""
Set a model as the recommended model.
This will clear the isRecommended flag from any other model and set it
on the specified model. The model must be enabled.
Args:
model_id: UUID of the model to set as recommended
Returns:
Tuple of (the updated model, previous recommended model slug or None)
Raises:
ValueError: If model not found or not enabled
"""
# First, verify the model exists and is enabled
target_model = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id}
)
if not target_model:
raise ValueError(f"Model with id '{model_id}' not found")
if not target_model.isEnabled:
raise ValueError(
f"Cannot set disabled model '{target_model.slug}' as recommended"
)
# Get the current recommended model (if any)
current_recommended = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True}
)
previous_slug = current_recommended.slug if current_recommended else None
# Use a transaction to ensure atomicity
async with transaction() as tx:
# Clear isRecommended from all models
await tx.llmmodel.update_many(
where={"isRecommended": True},
data={"isRecommended": False},
)
# Set the new recommended model
await tx.llmmodel.update(
where={"id": model_id},
data={"isRecommended": True},
)
# Fetch and return the updated model
updated_record = await prisma.models.LlmModel.prisma().find_unique(
where={"id": model_id},
include={"Costs": True, "Creator": True},
)
if not updated_record:
raise ValueError("Failed to fetch updated model")
return _map_model(updated_record), previous_slug
async def get_recommended_model_slug() -> str | None:
"""
Get the slug of the currently recommended LLM model.
Returns:
The slug of the recommended model, or None if no model is marked as recommended.
"""
record = await prisma.models.LlmModel.prisma().find_first(
where={"isRecommended": True, "isEnabled": True},
)
return record.slug if record else None

View File

@@ -0,0 +1,235 @@
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Optional
import prisma.enums
import pydantic
from backend.util.models import Pagination
# Pattern for valid model slugs: alphanumeric start, then alphanumeric, dots, underscores, slashes, hyphens
SLUG_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/-]*$")
class LlmModelCost(pydantic.BaseModel):
id: str
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = None
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCreator(pydantic.BaseModel):
"""Represents the organization that created/trained the model (e.g., OpenAI, Meta)."""
id: str
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModel(pydantic.BaseModel):
id: str
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
creator: Optional[LlmModelCreator] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
is_recommended: bool = False
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
class LlmProvider(pydantic.BaseModel):
id: str
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = None
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
models: list[LlmModel] = pydantic.Field(default_factory=list)
class LlmProvidersResponse(pydantic.BaseModel):
providers: list[LlmProvider]
class LlmModelsResponse(pydantic.BaseModel):
models: list[LlmModel]
pagination: Optional[Pagination] = None
class LlmCreatorsResponse(pydantic.BaseModel):
creators: list[LlmModelCreator]
class UpsertLlmProviderRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
default_credential_provider: Optional[str] = None
default_credential_id: Optional[str] = None
default_credential_type: Optional[str] = "api_key"
supports_tools: bool = True
supports_json_output: bool = True
supports_reasoning: bool = False
supports_parallel_tool: bool = False
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class UpsertLlmCreatorRequest(pydantic.BaseModel):
name: str
display_name: str
description: Optional[str] = None
website_url: Optional[str] = None
logo_url: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class LlmModelCostInput(pydantic.BaseModel):
unit: prisma.enums.LlmCostUnit = prisma.enums.LlmCostUnit.RUN
credit_cost: int
credential_provider: str
credential_id: Optional[str] = None
credential_type: Optional[str] = "api_key"
currency: Optional[str] = None
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
class CreateLlmModelRequest(pydantic.BaseModel):
slug: str
display_name: str
description: Optional[str] = None
provider_id: str
creator_id: Optional[str] = None
context_window: int
max_output_tokens: Optional[int] = None
is_enabled: bool = True
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
costs: list[LlmModelCostInput]
@pydantic.field_validator("slug")
@classmethod
def validate_slug(cls, v: str) -> str:
if not v or len(v) > 100:
raise ValueError("Slug must be 1-100 characters")
if not SLUG_PATTERN.match(v):
raise ValueError(
"Slug must start with alphanumeric and contain only "
"alphanumeric characters, dots, underscores, slashes, or hyphens"
)
return v
class UpdateLlmModelRequest(pydantic.BaseModel):
display_name: Optional[str] = None
description: Optional[str] = None
context_window: Optional[int] = None
max_output_tokens: Optional[int] = None
is_enabled: Optional[bool] = None
capabilities: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
provider_id: Optional[str] = None
creator_id: Optional[str] = None
costs: Optional[list[LlmModelCostInput]] = None
class ToggleLlmModelRequest(pydantic.BaseModel):
is_enabled: bool
migrate_to_slug: Optional[str] = None
migration_reason: Optional[str] = None # e.g., "Provider outage"
# Custom pricing override for migrated workflows. When set, billing should use
# this cost instead of the target model's cost for affected nodes.
# See LlmModelMigration in schema.prisma for full documentation.
custom_credit_cost: Optional[int] = None
class ToggleLlmModelResponse(pydantic.BaseModel):
model: LlmModel
nodes_migrated: int = 0
migrated_to_slug: Optional[str] = None
migration_id: Optional[str] = None # ID of the migration record for revert
class DeleteLlmModelResponse(pydantic.BaseModel):
deleted_model_slug: str
deleted_model_display_name: str
replacement_model_slug: Optional[str] = None
nodes_migrated: int
message: str
class LlmModelUsageResponse(pydantic.BaseModel):
model_slug: str
node_count: int
# Migration tracking models
class LlmModelMigration(pydantic.BaseModel):
id: str
source_model_slug: str
target_model_slug: str
reason: Optional[str] = None
node_count: int
# Custom pricing override - billing should use this instead of target model's cost
custom_credit_cost: Optional[int] = None
is_reverted: bool = False
created_at: datetime
reverted_at: Optional[datetime] = None
class LlmMigrationsResponse(pydantic.BaseModel):
migrations: list[LlmModelMigration]
class RevertMigrationRequest(pydantic.BaseModel):
re_enable_source_model: bool = (
True # Whether to re-enable the source model if disabled
)
class RevertMigrationResponse(pydantic.BaseModel):
migration_id: str
source_model_slug: str
target_model_slug: str
nodes_reverted: int
nodes_already_changed: int = (
0 # Nodes that were modified since migration (not reverted)
)
source_model_re_enabled: bool = False # Whether the source model was re-enabled
message: str
class SetRecommendedModelRequest(pydantic.BaseModel):
model_id: str
class SetRecommendedModelResponse(pydantic.BaseModel):
model: LlmModel
previous_recommended_slug: Optional[str] = None
message: str
class RecommendedModelResponse(pydantic.BaseModel):
model: Optional[LlmModel] = None
slug: Optional[str] = None

View File

@@ -0,0 +1,29 @@
import autogpt_libs.auth
import fastapi
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
router = fastapi.APIRouter(
prefix="/llm",
tags=["llm"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
"""List all enabled LLM models available to users."""
return await llm_db.list_models(enabled_only=True, page=page, page_size=page_size)
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
"""List all LLM providers with their enabled models."""
providers = await llm_db.list_providers(include_models=True, enabled_only=True)
return llm_model.LlmProvidersResponse(providers=providers)

View File

@@ -38,7 +38,6 @@ class Flag(str, Enum):
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
CHAT = "chat"
COPILOT_SDK = "copilot-sdk"
def is_configured() -> bool:

View File

@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
self.ssl_hostname = ssl_hostname
self.ip_addresses = ip_addresses
self._default = aiohttp.ThreadedResolver()
self._default = aiohttp.AsyncResolver()
async def resolve(self, host, port=0, family=socket.AF_INET):
if host == self.ssl_hostname:
@@ -467,7 +467,7 @@ class Requests:
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
session_kwargs: dict = {}
session_kwargs = {}
if connector:
session_kwargs["connector"] = connector

View File

@@ -662,17 +662,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
mem0_api_key: str = Field(default="", description="Mem0 API key")
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
linear_api_key: str = Field(
default="", description="Linear API key for system-level operations"
)
linear_feature_request_project_id: str = Field(
default="",
description="Linear project ID where feature requests are tracked",
)
linear_feature_request_team_id: str = Field(
default="",
description="Linear team ID used when creating feature request issues",
)
linear_client_id: str = Field(default="", description="Linear client ID")
linear_client_secret: str = Field(default="", description="Linear client secret")

View File

@@ -0,0 +1,81 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
-- CreateTable
CREATE TABLE "LlmProvider" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"defaultCredentialProvider" TEXT,
"defaultCredentialId" TEXT,
"defaultCredentialType" TEXT,
"supportsTools" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT TRUE,
"supportsReasoning" BOOLEAN NOT NULL DEFAULT FALSE,
"supportsParallelTool" BOOLEAN NOT NULL DEFAULT FALSE,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmProvider_name_key" UNIQUE ("name")
);
-- CreateTable
CREATE TABLE "LlmModel" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"slug" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"providerId" TEXT NOT NULL,
"contextWindow" INTEGER NOT NULL,
"maxOutputTokens" INTEGER,
"isEnabled" BOOLEAN NOT NULL DEFAULT TRUE,
"capabilities" JSONB NOT NULL DEFAULT '{}'::jsonb,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id"),
CONSTRAINT "LlmModel_slug_key" UNIQUE ("slug")
);
-- CreateTable
CREATE TABLE "LlmModelCost" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
"creditCost" INTEGER NOT NULL,
"credentialProvider" TEXT NOT NULL,
"credentialId" TEXT,
"credentialType" TEXT,
"currency" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}'::jsonb,
"llmModelId" TEXT NOT NULL,
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
-- CreateIndex
CREATE INDEX "LlmModel_slug_idx" ON "LlmModel"("slug");
-- CreateIndex
CREATE INDEX "LlmModelCost_llmModelId_idx" ON "LlmModelCost"("llmModelId");
-- CreateIndex
CREATE INDEX "LlmModelCost_credentialProvider_idx" ON "LlmModelCost"("credentialProvider");
-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCost_llmModelId_credentialProvider_unit_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit");
-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,226 @@
-- Seed LLM Registry from existing hard-coded data
-- This migration populates the LlmProvider, LlmModel, and LlmModelCost tables
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries
-- Insert Providers
INSERT INTO "LlmProvider" ("id", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "supportsTools", "supportsJsonOutput", "supportsReasoning", "supportsParallelTool", "metadata")
VALUES
(gen_random_uuid(), 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', true, true, true, true, '{}'::jsonb),
(gen_random_uuid(), 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', true, true, true, false, '{}'::jsonb),
(gen_random_uuid(), 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', true, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', false, true, false, false, '{}'::jsonb),
(gen_random_uuid(), 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', true, true, false, false, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;
-- Insert Models (using CTEs to reference provider IDs)
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb
FROM (VALUES
-- OpenAI models
('o3', 'O3', 'openai', 200000, 100000),
('o3-mini', 'O3 Mini', 'openai', 200000, 100000),
('o1', 'O1', 'openai', 200000, 100000),
('o1-mini', 'O1 Mini', 'openai', 128000, 65536),
('gpt-5-2025-08-07', 'GPT 5', 'openai', 400000, 128000),
('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 400000, 128000),
('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 400000, 128000),
('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 400000, 128000),
('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 400000, 16384),
('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 1000000, 32768),
('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 1047576, 32768),
('gpt-4o-mini', 'GPT 4o Mini', 'openai', 128000, 16384),
('gpt-4o', 'GPT 4o', 'openai', 128000, 16384),
('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 128000, 4096),
('gpt-3.5-turbo', 'GPT 3.5 Turbo', 'openai', 16385, 4096),
-- Anthropic models
('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 200000, 32000),
('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 200000, 32000),
('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 200000, 64000),
('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 200000, 64000),
('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 200000, 64000),
('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 200000, 64000),
('claude-3-7-sonnet-20250219', 'Claude 3.7 Sonnet', 'anthropic', 200000, 64000),
('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 200000, 4096),
-- AI/ML API models
('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 32000, 8000),
('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 128000, 40000),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 128000, NULL),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 131000, 2000),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 128000, NULL),
-- Groq models
('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 128000, 32768),
('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 128000, 8192),
-- Ollama models
('llama3.3', 'Llama 3.3', 'ollama', 8192, NULL),
('llama3.2', 'Llama 3.2', 'ollama', 8192, NULL),
('llama3', 'Llama 3', 'ollama', 8192, NULL),
('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 8192, NULL),
('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 32768, NULL),
-- OpenRouter models
('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 1050000, 8192),
('google/gemini-3-pro-preview', 'Gemini 3 Pro Preview', 'open_router', 1048576, 65535),
('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 1048576, 8192),
('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 1048576, 65535),
('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 1048576, 8192),
('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 128000, 4096),
('cohere/command-r-08-2024', 'Command R', 'open_router', 128000, 4096),
('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 128000, 4096),
('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 64000, 2048),
('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 163840, 163840),
('perplexity/sonar', 'Perplexity Sonar', 'open_router', 127000, 8000),
('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 200000, 8000),
('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 128000, 16000),
('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 131000, 4096),
('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 12288, 12288),
('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 131072, 131072),
('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 131072, 32768),
('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 300000, 5120),
('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 128000, 5120),
('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 300000, 5120),
('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 65536, 4096),
('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 4096, 4096),
('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 131072, 131072),
('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 1048576, 1000000),
('x-ai/grok-4', 'Grok 4', 'open_router', 256000, 256000),
('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 2000000, 30000),
('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 256000, 10000),
('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 131000, 131000),
('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 262144, 262144),
('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 262144, 262144),
-- Llama API models
('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 128000, 4028),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 128000, 4028),
('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 128000, 4028),
('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 128000, 4028),
-- v0 models
('v0-1.5-md', 'v0 1.5 MD', 'v0', 128000, 64000),
('v0-1.5-lg', 'v0 1.5 LG', 'v0', 512000, 64000),
('v0-1.0-md', 'v0 1.0 MD', 'v0', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM (VALUES
-- OpenAI costs
('o3', 4),
('o3-mini', 2),
('o1', 16),
('o1-mini', 4),
('gpt-5-2025-08-07', 2),
('gpt-5.1-2025-11-13', 5),
('gpt-5-mini-2025-08-07', 1),
('gpt-5-nano-2025-08-07', 1),
('gpt-5-chat-latest', 5),
('gpt-4.1-2025-04-14', 2),
('gpt-4.1-mini-2025-04-14', 1),
('gpt-4o-mini', 1),
('gpt-4o', 3),
('gpt-4-turbo', 10),
('gpt-3.5-turbo', 1),
-- Anthropic costs
('claude-opus-4-1-20250805', 21),
('claude-opus-4-20250514', 21),
('claude-sonnet-4-20250514', 5),
('claude-haiku-4-5-20251001', 4),
('claude-opus-4-5-20251101', 14),
('claude-sonnet-4-5-20250929', 9),
('claude-3-7-sonnet-20250219', 5),
('claude-3-haiku-20240307', 1),
-- AI/ML API costs
('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
('nvidia/llama-3.1-nemotron-70b-instruct', 1),
('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
-- Groq costs
('llama-3.3-70b-versatile', 1),
('llama-3.1-8b-instant', 1),
-- Ollama costs
('llama3.3', 1),
('llama3.2', 1),
('llama3', 1),
('llama3.1:405b', 1),
('dolphin-mistral:latest', 1),
-- OpenRouter costs
('google/gemini-2.5-pro-preview-03-25', 4),
('google/gemini-3-pro-preview', 5),
('mistralai/mistral-nemo', 1),
('cohere/command-r-08-2024', 1),
('cohere/command-r-plus-08-2024', 3),
('deepseek/deepseek-chat', 2),
('perplexity/sonar', 1),
('perplexity/sonar-pro', 5),
('perplexity/sonar-deep-research', 10),
('nousresearch/hermes-3-llama-3.1-405b', 1),
('nousresearch/hermes-3-llama-3.1-70b', 1),
('amazon/nova-lite-v1', 1),
('amazon/nova-micro-v1', 1),
('amazon/nova-pro-v1', 1),
('microsoft/wizardlm-2-8x22b', 1),
('gryphe/mythomax-l2-13b', 1),
('meta-llama/llama-4-scout', 1),
('meta-llama/llama-4-maverick', 1),
('x-ai/grok-4', 9),
('x-ai/grok-4-fast', 1),
('x-ai/grok-4.1-fast', 1),
('x-ai/grok-code-fast-1', 1),
('moonshotai/kimi-k2', 1),
('qwen/qwen3-235b-a22b-thinking-2507', 1),
('qwen/qwen3-coder', 9),
('google/gemini-2.5-flash', 1),
('google/gemini-2.0-flash-001', 1),
('google/gemini-2.5-flash-lite-preview-06-17', 1),
('google/gemini-2.0-flash-lite-001', 1),
('deepseek/deepseek-r1-0528', 1),
('openai/gpt-oss-120b', 1),
('openai/gpt-oss-20b', 1),
-- Llama API costs
('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
('Llama-3.3-8B-Instruct', 1),
('Llama-3.3-70B-Instruct', 1),
-- v0 costs
('v0-1.5-md', 1),
('v0-1.5-lg', 2),
('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;

View File

@@ -0,0 +1,25 @@
-- CreateTable
CREATE TABLE "LlmModelMigration" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"sourceModelSlug" TEXT NOT NULL,
"targetModelSlug" TEXT NOT NULL,
"reason" TEXT,
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
"nodeCount" INTEGER NOT NULL,
"customCreditCost" INTEGER,
"isReverted" BOOLEAN NOT NULL DEFAULT false,
"revertedAt" TIMESTAMP(3),
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_idx" ON "LlmModelMigration"("sourceModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
-- CreateIndex
CREATE INDEX "LlmModelMigration_isReverted_idx" ON "LlmModelMigration"("isReverted");

View File

@@ -0,0 +1,127 @@
-- Add LlmModelCreator table
-- Creator represents who made/trained the model (e.g., OpenAI, Meta)
-- This is distinct from Provider who hosts/serves the model (e.g., OpenRouter)
-- Create the LlmModelCreator table
CREATE TABLE "LlmModelCreator" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"name" TEXT NOT NULL,
"displayName" TEXT NOT NULL,
"description" TEXT,
"websiteUrl" TEXT,
"logoUrl" TEXT,
"metadata" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);
-- Create unique index on name
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
-- Add creatorId column to LlmModel
ALTER TABLE "LlmModel" ADD COLUMN "creatorId" TEXT;
-- Add foreign key constraint
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- Create index on creatorId
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Seed creators based on known model creators
INSERT INTO "LlmModelCreator" ("id", "updatedAt", "name", "displayName", "description", "websiteUrl", "metadata")
VALUES
(gen_random_uuid(), CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT models', 'https://openai.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude models', 'https://anthropic.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama models', 'https://ai.meta.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini models', 'https://deepmind.google', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'mistral', 'Mistral AI', 'Creator of Mistral models', 'https://mistral.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command models', 'https://cohere.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek models', 'https://deepseek.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar models', 'https://perplexity.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'qwen', 'Qwen (Alibaba)', 'Creator of Qwen models', 'https://qwenlm.github.io', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova models', 'https://aws.amazon.com/bedrock', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of WizardLM models', 'https://microsoft.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'moonshot', 'Moonshot AI', 'Creator of Kimi models', 'https://moonshot.cn', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'nous_research', 'Nous Research', 'Creator of Hermes models', 'https://nousresearch.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 models', 'https://vercel.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'cognitive_computations', 'Cognitive Computations', 'Creator of Dolphin models', 'https://erichartford.com', '{}'),
(gen_random_uuid(), CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', '{}')
ON CONFLICT ("name") DO NOTHING;
-- Update existing models with their creators
-- OpenAI models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'openai')
WHERE "slug" LIKE 'gpt-%' OR "slug" LIKE 'o1%' OR "slug" LIKE 'o3%' OR "slug" LIKE 'openai/%';
-- Anthropic models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'anthropic')
WHERE "slug" LIKE 'claude-%';
-- Meta/Llama models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'meta')
WHERE "slug" LIKE 'llama%' OR "slug" LIKE 'Llama%' OR "slug" LIKE 'meta-llama/%' OR "slug" LIKE '%/llama-%';
-- Google models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'google')
WHERE "slug" LIKE 'google/%' OR "slug" LIKE 'gemini%';
-- Mistral models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'mistral')
WHERE "slug" LIKE 'mistral%' OR "slug" LIKE 'mistralai/%';
-- Cohere models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cohere')
WHERE "slug" LIKE 'cohere/%' OR "slug" LIKE 'command-%';
-- DeepSeek models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'deepseek')
WHERE "slug" LIKE 'deepseek/%' OR "slug" LIKE 'deepseek-%';
-- Perplexity models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'perplexity')
WHERE "slug" LIKE 'perplexity/%' OR "slug" LIKE 'sonar%';
-- Qwen models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'qwen')
WHERE "slug" LIKE 'Qwen/%' OR "slug" LIKE 'qwen/%';
-- xAI/Grok models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'xai')
WHERE "slug" LIKE 'x-ai/%' OR "slug" LIKE 'grok%';
-- Amazon models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'amazon')
WHERE "slug" LIKE 'amazon/%' OR "slug" LIKE 'nova-%';
-- Microsoft models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'microsoft')
WHERE "slug" LIKE 'microsoft/%' OR "slug" LIKE 'wizardlm%';
-- Moonshot models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'moonshot')
WHERE "slug" LIKE 'moonshotai/%' OR "slug" LIKE 'kimi%';
-- NVIDIA models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nvidia')
WHERE "slug" LIKE 'nvidia/%' OR "slug" LIKE '%nemotron%';
-- Nous Research models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'nous_research')
WHERE "slug" LIKE 'nousresearch/%' OR "slug" LIKE 'hermes%';
-- Vercel/v0 models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'vercel')
WHERE "slug" LIKE 'v0-%';
-- Dolphin models (Cognitive Computations / Eric Hartford)
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'cognitive_computations')
WHERE "slug" LIKE 'dolphin-%';
-- Gryphe models
UPDATE "LlmModel" SET "creatorId" = (SELECT "id" FROM "LlmModelCreator" WHERE "name" = 'gryphe')
WHERE "slug" LIKE 'gryphe/%' OR "slug" LIKE 'mythomax%';

View File

@@ -0,0 +1,4 @@
-- CreateIndex
-- Index for efficient LLM model lookups on AgentNode.constantInput->>'model'
-- This improves performance of model migration queries in the LLM registry
CREATE INDEX "AgentNode_constantInput_model_idx" ON "AgentNode" ((("constantInput"->>'model')));

View File

@@ -0,0 +1,52 @@
-- Add GPT-5.2 model and update O3 slug
-- This migration adds the new GPT-5.2 model added in dev branch
-- Update O3 slug to match dev branch format
UPDATE "LlmModel"
SET "slug" = 'o3-2025-04-16'
WHERE "slug" = 'o3';
-- Update cost reference for O3 if needed
-- (costs are linked by model ID, so no update needed)
-- Add GPT-5.2 model
WITH provider_id AS (
SELECT "id" FROM "LlmProvider" WHERE "name" = 'openai'
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
gen_random_uuid(),
'gpt-5.2-2025-12-11',
'GPT 5.2',
'OpenAI GPT-5.2 model',
p."id",
400000,
128000,
true,
'{}'::jsonb,
'{}'::jsonb
FROM provider_id p
ON CONFLICT ("slug") DO NOTHING;
-- Add cost for GPT-5.2
WITH model_id AS (
SELECT m."id", p."name" as provider_name
FROM "LlmModel" m
JOIN "LlmProvider" p ON p."id" = m."providerId"
WHERE m."slug" = 'gpt-5.2-2025-12-11'
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
3, -- Same cost tier as GPT-5.1
m.provider_name,
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id"
FROM model_id m
WHERE NOT EXISTS (
SELECT 1 FROM "LlmModelCost" c WHERE c."llmModelId" = m."id"
);

View File

@@ -0,0 +1,11 @@
-- Add isRecommended field to LlmModel table
-- This allows admins to mark a model as the recommended default
ALTER TABLE "LlmModel" ADD COLUMN "isRecommended" BOOLEAN NOT NULL DEFAULT false;
-- Set gpt-4o-mini as the default recommended model (if it exists)
UPDATE "LlmModel" SET "isRecommended" = true WHERE "slug" = 'gpt-4o-mini' AND "isEnabled" = true;
-- Create unique partial index to enforce only one recommended model at the database level
-- This prevents multiple rows from having isRecommended = true
CREATE UNIQUE INDEX "LlmModel_single_recommended_idx" ON "LlmModel" ("isRecommended") WHERE "isRecommended" = true;

View File

@@ -0,0 +1,61 @@
-- Add new columns to LlmModel table for extended model metadata
-- These columns support the LLM Picker UI enhancements
-- Add priceTier column: 1=cheapest, 2=medium, 3=expensive
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "priceTier" INTEGER NOT NULL DEFAULT 1;
-- Add creatorId column for model creator relationship (if not exists)
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "creatorId" TEXT;
-- Add isRecommended column (if not exists)
ALTER TABLE "LlmModel" ADD COLUMN IF NOT EXISTS "isRecommended" BOOLEAN NOT NULL DEFAULT FALSE;
-- Add index on creatorId if not exists
CREATE INDEX IF NOT EXISTS "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
-- Add foreign key for creatorId if not exists
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'LlmModel_creatorId_fkey') THEN
-- Only add FK if LlmModelCreator table exists
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'LlmModelCreator') THEN
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey"
FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
END IF;
END IF;
END $$;
-- Update priceTier values for existing models based on original MODEL_METADATA
-- Tier 1 = cheapest, Tier 2 = medium, Tier 3 = expensive
-- OpenAI models
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o3';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'o3-mini';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'o1';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'o1-mini';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-5.2';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5.1';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-mini';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-5-nano';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-5-chat-latest';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" LIKE 'gpt-4.1%';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-4o-mini';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" = 'gpt-4o';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'gpt-4-turbo';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'gpt-3.5-turbo';
-- Anthropic models
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude-opus%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude-sonnet%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE 'claude%-4-5-sonnet%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'claude%-haiku%';
UPDATE "LlmModel" SET "priceTier" = 1 WHERE "slug" = 'claude-3-haiku-20240307';
-- OpenRouter models - Pro/expensive tiers
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE 'google/gemini%-pro%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%command-r-plus%';
UPDATE "LlmModel" SET "priceTier" = 2 WHERE "slug" LIKE '%sonar-pro%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%sonar-deep-research%';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" = 'x-ai/grok-4';
UPDATE "LlmModel" SET "priceTier" = 3 WHERE "slug" LIKE '%qwen3-coder%';

View File

@@ -0,0 +1,6 @@
-- Add composite index on LlmModelMigration for optimized active migration queries
-- This index improves performance when querying for non-reverted migrations by model slug
-- Used by the billing system to apply customCreditCost overrides
-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");

View File

@@ -0,0 +1,65 @@
-- Sync LLM models with latest dev branch changes
-- This migration adds new models and removes deprecated ones
-- Remove models that were deleted from dev
DELETE FROM "LlmModelCost" WHERE "llmModelId" IN (
SELECT "id" FROM "LlmModel" WHERE "slug" IN ('o3', 'o3-mini', 'claude-3-7-sonnet-20250219')
);
DELETE FROM "LlmModel" WHERE "slug" IN ('o3', 'o3-mini', 'claude-3-7-sonnet-20250219');
-- Add new models from dev
WITH provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModel" ("id", "slug", "displayName", "description", "providerId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata", "createdAt", "updatedAt")
SELECT
gen_random_uuid(),
model_slug,
model_display_name,
NULL,
p."id",
context_window,
max_output_tokens,
true,
'{}'::jsonb,
'{}'::jsonb,
NOW(),
NOW()
FROM (VALUES
-- New OpenAI model
('gpt-5.2-2025-12-11', 'GPT 5.2', 'openai', 400000, 128000),
-- New Anthropic model
('claude-opus-4-6', 'Claude 4.6 Opus', 'anthropic', 200000, 64000)
) AS models(model_slug, model_display_name, provider_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
ON CONFLICT ("slug") DO NOTHING;
-- Add costs for new models
WITH model_ids AS (
SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId", "createdAt", "updatedAt")
SELECT
gen_random_uuid(),
'RUN'::"LlmCostUnit",
cost,
p."name",
NULL,
'api_key',
NULL,
'{}'::jsonb,
m."id",
NOW(),
NOW()
FROM (VALUES
-- New model costs (estimate based on similar models)
('gpt-5.2-2025-12-11', 5), -- Similar to GPT 5.1
('claude-opus-4-6', 21) -- Similar to other Opus 4.x models
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") DO NOTHING;

View File

@@ -897,29 +897,6 @@ files = [
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
]
[[package]]
name = "claude-agent-sdk"
version = "0.1.35"
description = "Python SDK for Claude Code"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
]
[package.dependencies]
anyio = ">=4.0.0"
mcp = ">=0.1.0"
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
[[package]]
name = "cleo"
version = "2.1.0"
@@ -2616,18 +2593,6 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.3"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
]
[[package]]
name = "huggingface-hub"
version = "1.4.1"
@@ -3345,39 +3310,6 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mcp"
version = "1.26.0"
description = "Model Context Protocol SDK"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
]
[package.dependencies]
anyio = ">=4.5"
httpx = ">=0.27.1"
httpx-sse = ">=0.4"
jsonschema = ">=4.20.0"
pydantic = ">=2.11.0,<3.0.0"
pydantic-settings = ">=2.5.2"
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
python-multipart = ">=0.0.9"
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
sse-starlette = ">=1.6.1"
starlette = ">=0.27"
typing-extensions = ">=4.9.0"
typing-inspection = ">=0.4.1"
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
[package.extras]
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
[[package]]
name = "mdurl"
version = "0.1.2"
@@ -6062,7 +5994,7 @@ description = "Python for Window Extensions"
optional = false
python-versions = "*"
groups = ["main"]
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
markers = "platform_system == \"Windows\""
files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -7042,28 +6974,6 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sse-starlette"
version = "3.2.0"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
]
[package.dependencies]
anyio = ">=4.7.0"
starlette = ">=0.49.1"
[package.extras]
daphne = ["daphne (>=4.2.0)"]
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
granian = ["granian (>=2.3.1)"]
uvicorn = ["uvicorn (>=0.34.0)"]
[[package]]
name = "stagehand"
version = "0.5.9"
@@ -8530,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f"

View File

@@ -16,7 +16,6 @@ anthropic = "^0.79.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
claude-agent-sdk = "^0.1.0"
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"

View File

@@ -1143,6 +1143,154 @@ enum APIKeyStatus {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
///////////// LLM REGISTRY AND BILLING DATA /////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
// LlmCostUnit: Defines how LLM MODEL costs are calculated (per run or per token).
// This is distinct from BlockCostType (in backend/data/block.py) which defines
// how BLOCK EXECUTION costs are calculated (per run, per byte, or per second).
// LlmCostUnit is for pricing individual LLM model API calls in the registry,
// while BlockCostType is for billing platform block executions.
enum LlmCostUnit {
RUN
TOKENS
}
model LlmModelCreator {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique // e.g., "openai", "anthropic", "meta"
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
description String?
websiteUrl String? // Link to creator's website
logoUrl String? // URL to creator's logo
metadata Json @default("{}")
Models LlmModel[]
}
model LlmProvider {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
name String @unique
displayName String
description String?
defaultCredentialProvider String?
defaultCredentialId String?
defaultCredentialType String?
supportsTools Boolean @default(true)
supportsJsonOutput Boolean @default(true)
supportsReasoning Boolean @default(false)
supportsParallelTool Boolean @default(false)
metadata Json @default("{}")
Models LlmModel[]
}
model LlmModel {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
slug String @unique
displayName String
description String?
providerId String
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
creatorId String?
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
contextWindow Int
maxOutputTokens Int?
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive
isEnabled Boolean @default(true)
isRecommended Boolean @default(false)
capabilities Json @default("{}")
metadata Json @default("{}")
Costs LlmModelCost[]
@@index([providerId, isEnabled])
@@index([creatorId])
@@index([slug])
}
model LlmModelCost {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
unit LlmCostUnit @default(RUN)
creditCost Int
credentialProvider String
credentialId String?
credentialType String?
currency String?
metadata Json @default("{}")
llmModelId String
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
@@unique([llmModelId, credentialProvider, unit])
@@index([llmModelId])
@@index([credentialProvider])
}
// Tracks model migrations for revert capability
// When a model is disabled with migration, we record which nodes were affected
// so they can be reverted when the original model is back online
model LlmModelMigration {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
sourceModelSlug String // The original model that was disabled
targetModelSlug String // The model workflows were migrated to
reason String? // Why the migration happened (e.g., "Provider outage")
// Track affected nodes as JSON array of node IDs
// Format: ["node-uuid-1", "node-uuid-2", ...]
migratedNodeIds Json @default("[]")
nodeCount Int // Number of nodes migrated
// Custom pricing override for migrated workflows during the migration period.
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
// to avoid billing surprises, or offer a discount during the transition.
//
// IMPORTANT: This field is intended for integration with the billing system.
// When billing calculates costs for nodes affected by this migration, it should
// check if customCreditCost is set and use it instead of the target model's cost.
// If null, the target model's normal cost applies.
//
// TODO: Integrate with billing system to apply this override during cost calculation.
customCreditCost Int?
// Revert tracking
isReverted Boolean @default(false)
revertedAt DateTime?
@@index([sourceModelSlug])
@@index([targetModelSlug])
@@index([isReverted])
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}
////////////// OAUTH PROVIDER TABLES //////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

View File

@@ -1,66 +0,0 @@
from typing import cast
import pytest
from backend.blocks.jina._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
JinaCredentialsInput,
)
from backend.blocks.jina.search import ExtractWebsiteContentBlock
from backend.util.request import HTTPClientError
@pytest.mark.asyncio
async def test_extract_website_content_returns_content(monkeypatch):
block = ExtractWebsiteContentBlock()
input_data = block.Input(
url="https://example.com",
credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT),
raw_content=True,
)
async def fake_get_request(url, json=False, headers=None):
assert url == "https://example.com"
assert headers == {}
return "page content"
monkeypatch.setattr(block, "get_request", fake_get_request)
results = [
output
async for output in block.run(
input_data=input_data, credentials=TEST_CREDENTIALS
)
]
assert ("content", "page content") in results
assert all(key != "error" for key, _ in results)
@pytest.mark.asyncio
async def test_extract_website_content_handles_http_error(monkeypatch):
block = ExtractWebsiteContentBlock()
input_data = block.Input(
url="https://example.com",
credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT),
raw_content=False,
)
async def fake_get_request(url, json=False, headers=None):
raise HTTPClientError("HTTP 400 Error: Bad Request", 400)
monkeypatch.setattr(block, "get_request", fake_get_request)
results = [
output
async for output in block.run(
input_data=input_data, credentials=TEST_CREDENTIALS
)
]
assert ("content", "page content") not in results
error_messages = [value for key, value in results if key == "error"]
assert error_messages
assert "Client error (400)" in error_messages[0]
assert "https://example.com" in error_messages[0]

Some files were not shown because too many files have changed in this diff Show More